# MNAD Training

## Libraries import

In [None]:
import argparse
import os
import torch
import torch.utils.data as data
import torchvision.transforms as T
import torch.optim as optim
import sys
import json

from data.CustomDataset import CustomImageDataset
from data.CustomDataset import augment_dataset
from data.CustomDataset import show_augmented_dataset_info

## Constants

In [None]:
# Sub-directory joined onto <dataset_path>/<dataset_type> to form the data folder
# that is handed to CustomImageDataset.
DATASET_DIR_SUFFIX = 'images'

## Parameters

In [None]:
# Run configuration. The dict is wrapped in an argparse.Namespace below so the rest of
# the notebook can read the values as attributes (args.lr, args.epochs, ...).
args_dict = {
  "gpus": "1",                                    # GPU id(s) as a string (e.g. "1"), or None to skip moving the model to the GPU
  "train_batch_size": 4,                          # batch size for training
  "val_batch_size": 1,                            # batch size for validation
  "epochs": 60,                                   # number of epochs for training
  "loss_compact": 0.1,                            # weight of the feature compactness loss
  "loss_separate": 0.1,                           # weight of the feature separateness loss
  "h": 256,                                       # height of input images
  "w": 256,                                       # width of input images
  "c": 3,                                         # number of channels of input images
  "lr": 2e-4,                                     # initial learning rate
  "method": "recon",                              # target task for anomaly detection ("pred" or "recon")
  "t_length": 1,                                  # length of the frame sequences
  "fdim": 512,                                    # channel dimension of the features
  "mdim": 512,                                    # channel dimension of the memory items
  "msize": 10,                                    # number of the memory items
  "train_num_workers": 2,                         # number of workers for the train loader
  "val_num_workers": 1,                           # number of workers for the validation loader
  "dataset_type": "clean_road",                   # type of dataset: clean_road
  "dataset_path": "./dataset",                    # root directory of the data
  "label_path": "./dataset",                      # root directory of the labels
  "label_file": "metadata.csv",                   # name of the label file
  "exp_dir": "./log",                             # root directory of the logs
  "split_dataset": True,                          # whether to split the dataset
  "val_label_file": "validation_labels.csv"       # name of the validation label file (used if split_dataset is False)
}

args = argparse.Namespace(**args_dict)

## GPU Configurations

In [None]:
# Print whether CUDA is usable; when a GPU was requested and is usable, also print
# the device count, the current device index, the device-0 handle and its name.
print(torch.cuda.is_available())
if args.gpus is not None and torch.cuda.is_available():
    for describe_cuda in (torch.cuda.device_count, torch.cuda.current_device):
        print(describe_cuda())
    for describe_device in (torch.cuda.device, torch.cuda.get_device_name):
        print(describe_device(0))

In [None]:
# Choose which physical GPUs CUDA exposes to this process.
# NOTE(review): CUDA reads these variables when it initialises; the diagnostic cell
# above already queries the device, so confirm the selection still takes effect.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
if args.gpus is None:
    # No explicit request: expose device 0 only.
    gpus = "0"
else:
    # args.gpus may be one string ("1", "0,1", "0 1") or a sequence of ids (["0", "1"]).
    # Parse it into individual ids so separators already present in the string are not
    # duplicated and multi-digit ids are not split character by character.
    if isinstance(args.gpus, str):
        gpu_ids = args.gpus.replace(",", " ").split()
    else:
        gpu_ids = [str(gpu_id) for gpu_id in args.gpus]
    gpus = ",".join(gpu_ids)
os.environ["CUDA_VISIBLE_DEVICES"] = gpus

#torch.backends.cudnn.enabled = True # make sure to use cudnn for computational performance

## Data loading

In [None]:
# Locate the image folder and the label file of the selected dataset type
folder_parts = (args.dataset_path, args.dataset_type, DATASET_DIR_SUFFIX)
label_parts = (args.label_path, args.dataset_type, args.label_file)
data_folder = os.path.join(*folder_parts)
data_label_file = os.path.join(*label_parts)

# Only tensor conversion is applied; resizing to (args.h, args.w) is currently disabled:
#transform = T.Resize((args.h,args.w))
transform = T.Compose([T.ToTensor()])

# Build the full dataset and remember how many samples it holds
dataset = CustomImageDataset(
    data_label_file,
    data_folder,
    transform=transform,
    use_cv2=True,
)
dataset_size = len(dataset)

### Data splitting

In [None]:
# Obtain the training and validation sets: either split the loaded dataset, or use it
# entirely for training and read the validation samples from their own label file.
if not args.split_dataset:
    validation_label_file = os.path.join(
        args.label_path, args.dataset_type, args.val_label_file
    )
    train_dataset = dataset
    validation_dataset = CustomImageDataset(
        validation_label_file,
        data_folder,
        transform=transform,
        use_cv2=True,
    )
else:
    # The third element returned by the split is not used here
    train_dataset, validation_dataset, _ = dataset.split_train_validation_test()
train_size, validation_size = len(train_dataset), len(validation_dataset)

### Data augmentation

In [None]:
# Create augmentation transform list.
# The cells below append one dict per enabled augmentation, with the keys
# "name", "transform" and "applications_number"; the list is later passed to augment_dataset.
augmentation_transform_list = []

#### AutoAugment

In [None]:
# Switch: set to True to register AutoAugment in augmentation_transform_list
enabled = False
if enabled:
    transform_name = "AutoAugment"
    applications_number = 3  # presumably how many times the transform is applied — confirm in augment_dataset
    # The augmentation runs on PIL images, so convert tensor -> PIL -> augment -> tensor
    augmentation_transform = T.Compose([T.ToPILImage(), T.AutoAugment(), T.ToTensor()])
    transform_dict = dict(
        name=transform_name,
        transform=augmentation_transform,
        applications_number=applications_number,
    )
    augmentation_transform_list.append(transform_dict)

#### RandAugment

In [None]:
# Switch: set to True to register RandAugment in augmentation_transform_list
enabled = False
if enabled:
    transform_name = "RandAugment"
    applications_number = 3  # presumably how many times the transform is applied — confirm in augment_dataset
    # The augmentation runs on PIL images, so convert tensor -> PIL -> augment -> tensor
    augmentation_transform = T.Compose([T.ToPILImage(), T.RandAugment(), T.ToTensor()])
    transform_dict = dict(
        name=transform_name,
        transform=augmentation_transform,
        applications_number=applications_number,
    )
    augmentation_transform_list.append(transform_dict)

#### AugMix

In [None]:
# Switch: set to True to register AugMix in augmentation_transform_list
enabled = False
if enabled:
    transform_name = "AugMix"
    applications_number = 3  # presumably how many times the transform is applied — confirm in augment_dataset
    # The augmentation runs on PIL images, so convert tensor -> PIL -> augment -> tensor
    augmentation_transform = T.Compose([T.ToPILImage(), T.AugMix(), T.ToTensor()])
    transform_dict = dict(
        name=transform_name,
        transform=augmentation_transform,
        applications_number=applications_number,
    )
    augmentation_transform_list.append(transform_dict)

#### TrivialAugmentWide

In [None]:
# Switch: set to True to register TrivialAugmentWide in augmentation_transform_list
enabled = False
if enabled:
    transform_name = "TrivialAugmentWide"
    applications_number = 3  # presumably how many times the transform is applied — confirm in augment_dataset
    # The augmentation runs on PIL images, so convert tensor -> PIL -> augment -> tensor
    augmentation_transform = T.Compose([T.ToPILImage(), T.TrivialAugmentWide(), T.ToTensor()])
    transform_dict = dict(
        name=transform_name,
        transform=augmentation_transform,
        applications_number=applications_number,
    )
    augmentation_transform_list.append(transform_dict)

#### Create augmented dataset

In [None]:
# Build the training set actually used for training by handing the training split and
# the list of enabled augmentations to the project's augment_dataset helper.
augmented_train_dataset = augment_dataset(
    train_dataset,
    augmentation_transform_list,
    create_dict=False,
)
augmented_train_size = len(augmented_train_dataset)

### Data batching

In [None]:
# Wrap the datasets in DataLoaders.
# DataLoader's keyword arguments are `batch_size` and `num_workers`; the train/val
# prefixes belong only to the corresponding entries in `args`.

# Training: reshuffled every epoch, incomplete final batch dropped
train_batch = data.DataLoader(augmented_train_dataset, batch_size=args.train_batch_size,
                              shuffle=True, num_workers=args.train_num_workers, drop_last=True)
# NOTE: despite its name, this is the number of batches per epoch (len of the loader)
train_batch_size = len(train_batch)
# Validation: every sample is kept (drop_last=False)
validation_batch = data.DataLoader(validation_dataset, batch_size=args.val_batch_size,
                                   shuffle=True, num_workers=args.val_num_workers, drop_last=False)


### Show data info

In [None]:
# Report on the augmented training set through the project helper
# (what is shown is defined in data.CustomDataset, not in this notebook).
show_augmented_dataset_info(augmented_train_dataset)

## Model setting

In [None]:
# Build the network for the chosen task, then its optimizer and learning-rate schedule.
assert args.method in ('pred', 'recon'), 'Wrong task name'
if args.method == 'pred':
    # Frame-prediction variant
    from model.final_future_prediction_with_memory_spatial_sumonly_weight_ranking_top1 import *
    model = convAE(args.c, args.t_length, args.msize, args.fdim, args.mdim)
else:
    # Reconstruction variant
    from model.Reconstruction import *
    model = convAE(
        args.c,
        memory_size=args.msize,
        feature_dim=args.fdim,
        key_dim=args.mdim,
    )

# Only the encoder and decoder parameters are handed to the optimizer
params_encoder = list(model.encoder.parameters())
params_decoder = list(model.decoder.parameters())
params = [*params_encoder, *params_decoder]
optimizer = optim.Adam(params, lr=args.lr)
# Cosine annealing of the learning rate over the whole run (one step per epoch)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)

if args.gpus is not None and torch.cuda.is_available():
    model.cuda()

## Enable report

In [None]:
# Set and create (if necessary) the log directory: <exp_dir>/<dataset_type>/<method>
log_dir = os.path.join(args.exp_dir, args.dataset_type, args.method)
os.makedirs(log_dir, exist_ok=True)

# Record which augmentations were configured. The transform objects themselves are left
# out because they are not JSON-serialisable. Filtered copies are written instead of
# popping the key in place, so the shared list is not modified (anything still holding
# these dicts keeps its transforms) and re-running this cell does not raise KeyError.
augmentation_summary = [
    {key: value for key, value in transform_dict.items() if key != "transform"}
    for transform_dict in augmentation_transform_list
]
with open(os.path.join(log_dir, "augmentations.json"), "w") as augmentations_file:
    json.dump(augmentation_summary, augmentations_file)

# Redirect stdout to the log file. The training cell is responsible for restoring
# `orig_stdout` and closing `f`, so this handle intentionally stays open here.
orig_stdout = sys.stdout
f = open(os.path.join(log_dir, 'log.txt'), 'w')
sys.stdout = f

## Training

In [None]:
# Train the model, validate after every epoch, log per-epoch mean losses, and save the
# model, the memory items and the loss history into log_dir.
# While this cell runs, stdout points at the log file opened by the report cell; the
# try/finally below guarantees it is restored and the file closed even if training fails.

# pandas is only needed by this cell, to write the loss history as CSV
import pandas as pd

# Evaluate the device / task switches once instead of on every batch
use_cuda = args.gpus is not None and torch.cuda.is_available()
is_pred = args.method == 'pred'
task_name = 'Prediction' if is_pred else 'Reconstruction'

result_columns = ['epoch', 'phase', 'loss', 'loss_pixel', 'loss_compactness', 'loss_separateness']
loss_columns = result_columns[2:]
report_format = '{} Loss: {} {:.6f}/ Compactness {:.6f}/ Separateness {:.6f}/ Total {:.6f}'


def _epoch_means(totals, batch_count):
    """Return the per-batch mean of every accumulated loss (NaN when no batch ran)."""
    if batch_count == 0:
        return {name: float('nan') for name in totals}
    return {name: total / batch_count for name, total in totals.items()}


try:
    # Loss function (element-wise; reduced with torch.mean below)
    loss_func_mse = nn.MSELoss(reduction='none')
    # Initialize the memory items with random, row-normalised vectors
    m_items = F.normalize(torch.rand((args.msize, args.mdim), dtype=torch.float), dim=1)
    if use_cuda:
        m_items = m_items.cuda()

    # One dict per (epoch, phase); turned into a DataFrame once training is done
    result_rows = []

    for epoch in range(args.epochs):

        # ---- Training phase ----
        model.train()
        train_totals = dict.fromkeys(loss_columns, 0.0)
        train_batches = 0

        for images, labels in train_batch:
            # Always bind imgs, so the loop also works when no GPU is used
            imgs = images["file"]
            if use_cuda:
                imgs = imgs.cuda()

            # NOTE(review): the 12-channel split is hard-coded — presumably 4 RGB input
            # frames followed by the frame to predict; confirm against t_length and c.
            if is_pred:
                inputs, targets = imgs[:, 0:12], imgs[:, 12:]
            else:
                inputs, targets = imgs, imgs

            # The model returns the updated memory items, which are carried over to the next batch
            outputs, _, _, m_items, softmax_score_query, softmax_score_memory, train_separateness_loss, train_compactness_loss = model(inputs, m_items, True)

            optimizer.zero_grad()
            train_loss_pixel = torch.mean(loss_func_mse(outputs, targets))
            train_loss = (train_loss_pixel
                          + args.loss_compact * train_compactness_loss
                          + args.loss_separate * train_separateness_loss)
            # retain_graph=True is kept from the original code — presumably needed because
            # m_items is carried across iterations; TODO confirm.
            train_loss.backward(retain_graph=True)
            optimizer.step()

            train_totals['loss'] += train_loss.item()
            train_totals['loss_pixel'] += train_loss_pixel.item()
            train_totals['loss_compactness'] += train_compactness_loss.item()
            train_totals['loss_separateness'] += train_separateness_loss.item()
            train_batches += 1

        # ---- Validation phase ----
        model.eval()
        val_totals = dict.fromkeys(loss_columns, 0.0)
        val_batches = 0

        # No gradients are needed for validation; skipping them saves memory and time
        with torch.no_grad():
            for images, labels in validation_batch:
                imgs = images["file"]
                if use_cuda:
                    imgs = imgs.cuda()

                if is_pred:
                    inputs, targets = imgs[:, 0:12], imgs[:, 12:]
                else:
                    inputs, targets = imgs, imgs

                # The memory items are only read here (train flag False); the returned ones are ignored
                outputs, _, _, _, _, _, val_separateness_loss, val_compactness_loss = model(inputs, m_items, False)

                val_loss_pixel = torch.mean(loss_func_mse(outputs, targets))
                val_loss = (val_loss_pixel
                            + args.loss_compact * val_compactness_loss
                            + args.loss_separate * val_separateness_loss)

                val_totals['loss'] += val_loss.item()
                val_totals['loss_pixel'] += val_loss_pixel.item()
                val_totals['loss_compactness'] += val_compactness_loss.item()
                val_totals['loss_separateness'] += val_separateness_loss.item()
                val_batches += 1

        scheduler.step()

        # Store the results: means over all batches of the epoch, so the logged numbers
        # do not depend on which batch happened to be drawn last
        train_means = _epoch_means(train_totals, train_batches)
        val_means = _epoch_means(val_totals, val_batches)
        result_rows.append({'epoch': epoch, 'phase': 'train', **train_means})
        result_rows.append({'epoch': epoch, 'phase': 'validation', **val_means})

        print('----------------------------------------')
        print('Epoch:', epoch+1)
        for phase_name, means in (('Train', train_means), ('Validation', val_means)):
            print(report_format.format(phase_name, task_name, means['loss_pixel'],
                                       means['loss_compactness'], means['loss_separateness'],
                                       means['loss']))
        print('Memory_items:')
        print(m_items)
        print('----------------------------------------')

    print('Training is finished')

    # Save the model and the memory items
    torch.save(model, os.path.join(log_dir, 'model.pth'))
    torch.save(m_items, os.path.join(log_dir, 'keys.pt'))

    # Save the results (built in one go; DataFrame.append no longer exists in pandas >= 2.0)
    results = pd.DataFrame(result_rows, columns=result_columns)
    results.to_csv(os.path.join(log_dir, 'results.csv'))
finally:
    # Always hand stdout back and close the log file, even when training raised
    sys.stdout = orig_stdout
    f.close()