### Requirements

In [None]:
from dataset import*
from utility import*
from training import*
from baseline import*
from transformer import*

# Head pose DL model from https://github.com/thohemp/6drepnet
from sixdrepnet import SixDRepNet
import dlib

from torch.optim import lr_scheduler 
# Import models
from torchvision import models
from vit_pytorch.twins_svt import TwinsSVT # MEMORIA NON SUFFICIENTE RIPROVARLO
#from vit_pytorch.vit import ViT
from vit_pytorch.ats_vit import ViT
from vit_pytorch import SimpleViT
from vit_pytorch.crossformer import CrossFormer # MEMORIA NON SUFFICIENTE RIPROVARLO
from vit_pytorch.cross_vit import CrossViT

In [None]:
# Absolute path of the project checkout on the author's machine.
# The trailing slash is required: later cells build sub-paths by plain string
# concatenation (e.g. root_project + 'save/save_train').
root_project = '/home/anto/University/Driving-Visual-Attention/'

In [None]:
# Report whether CUDA is usable and select the device used by the rest of the
# notebook ("cuda" when available, otherwise "cpu").
cuda_available = torch.cuda.is_available()
print(f"We {'have' if cuda_available else 'do not have'} access to a GPU")
device = torch.device("cuda" if cuda_available else "cpu")
if cuda_available:
    # Diagnostics: current device index, device handle, number of visible
    # devices and the name of device 0.
    print(torch.cuda.current_device())
    print(torch.cuda.device(0))
    print(torch.cuda.device_count())
    print(torch.cuda.get_device_name(0))
print(device)

In [None]:
# Fix the random seed (42) for the whole run. seed_everything comes from one of
# the star-imported project modules; presumably it seeds Python/NumPy/PyTorch —
# confirm in its definition.
seed_everything(42)

##### Initialize pre-trained models for feature extraction

In [None]:
# Pre-trained helpers handed to the DataLoaderVisualizer instances below:
#   * predictor          - dlib 68-point facial landmark predictor
#   * headpose_extractor - SixDRepNet head-pose model
#   * face_detector      - dlib frontal face detector
# The landmark weights live under <root_project>/data.
landmarks_path = root_project + "data/shape_predictor_68_face_landmarks.dat"
predictor = dlib.shape_predictor(landmarks_path)
headpose_extractor = SixDRepNet()
face_detector = dlib.get_frontal_face_detector()

### Data Loader and Visualization

##### Files where to write the paths and labels

In [None]:
# `percentage` is passed to the DataLoaderVisualizer instances and is also the
# numeric suffix of the save-file names.
percentage = 100

# Files where the paths and labels of each split are written.
# (Append '_complete' to a name to point at the "_complete" variant instead.)
save_train_file = f"{root_project}save/save_train{percentage}"
save_val_file = f"{root_project}save/save_val{percentage}"
save_test_file = f"{root_project}save/save_test{percentage}"

##### Train Validation and Test Loader

In [None]:
# One DataLoaderVisualizer per split. Every split must be given its own save
# file, otherwise two splits would share the same paths/labels file.
train_dataset_classloader = DataLoaderVisualizer(
    root_project, save_train_file, percentage,
    predictor, face_detector, headpose_extractor,
    'train', big_file=False,
)
val_dataset_classloader = DataLoaderVisualizer(
    root_project, save_val_file, percentage,
    predictor, face_detector, headpose_extractor,
    'val', big_file=False,
)
# The test split uses save_test_file (not the validation file).
test_dataset_classloader = DataLoaderVisualizer(
    root_project, save_test_file, percentage,
    predictor, face_detector, headpose_extractor,
    'test', big_file=False,
)

##### Visualization

In [None]:
#train_dataset_classloader.visualize_dataset()
#val_dataset_classloader.visualize_dataset()
#test_dataset_classloader.visualize_dataset()

### Pytorch Dataset 

In [None]:
# Target (height, width) of the eye crops fed to the model.
dim = (32, 64)
# Per-channel mean and std of the images, calculated in advance.
mean = (0.4570, 0.4422, 0.3900)
std = (0.2376, 0.2295, 0.2261)

# Preprocessing pipeline: convert to tensor, resize to `dim`, then standardise
# each channel with the precomputed statistics. Normalize must receive `std`
# as its std argument (passing `mean` twice would scale by the wrong values).
my_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(dim, antialias=True),
    transforms.Normalize(mean=mean, std=std, inplace=True),
])

In [None]:
# Training split, built from the training save file with the shared transforms.
train_dataset = DGAZEDataset(
    'train',
    save_train_file,
    my_transforms,
    big_file=False,
)
n_train = len(train_dataset)
print(f'Train dataset len is {n_train}')

In [None]:
# Show one training sample as a visual sanity check of the dataset pipeline.
# Fetch the sample once instead of re-running __getitem__ for every field.
sample = train_dataset[30]
image, additional_features, bbox = sample[0], sample[1], sample[3]

# imshow expects float RGB values in [0, 1]. Undo the Normalize step (the last
# entry of my_transforms) and clamp, rather than scaling the normalised tensor
# by 255, which would fall outside the displayable range and be clipped.
# NOTE(review): assumes sample[0] is the (C, H, W) tensor produced by
# my_transforms — confirm against DGAZEDataset.__getitem__.
normalize = my_transforms.transforms[-1]
mean_t = torch.as_tensor(normalize.mean, dtype=image.dtype).view(-1, 1, 1)
std_t = torch.as_tensor(normalize.std, dtype=image.dtype).view(-1, 1, 1)
img_np = (image * std_t + mean_t).clamp(0, 1).permute(1, 2, 0).numpy()

print(f"The bbox is: {bbox}")
print(f"The additional features are: {additional_features}")
plt.imshow(img_np)
plt.axis('off')
plt.show()

In [None]:
# Validation split, built from the validation save file with the shared transforms.
val_dataset = DGAZEDataset(
    'val',
    save_val_file,
    my_transforms,
    big_file=False,
)
n_val = len(val_dataset)
print(f'Val dataset len is {n_val}')

In [None]:
#test_dataset = DGAZEDataset('test',save_test_file, my_transforms)
#print(f'Test dataset len is {len(test_dataset)}')

In [None]:
# Merge datasets to increase the number of samples in the training or validation set
#from torch.utils.data import ConcatDataset
#val_dataset = ConcatDataset([val_dataset,test_dataset])
#train_dataset = ConcatDataset([train_dataset,test_dataset])

### Vision Transformer Model

##### Hyperparameters

In [None]:
# Training configuration used by the cells below.
EPOCHS = 20            # number of epochs to run (added to the checkpoint epoch when resuming)
BATCH_SIZE = 16        # batch size of both DataLoaders; also passed to validate()
THRESHOLD = 250        # threshold passed to validate(); also part of the run and checkpoint names
LR = 0.001             # Adam learning rate
BETAS = (0.9, 0.97)    # Adam betas
WEIGHT_DECAY = 1e-5    # Adam weight decay
STEP_SIZE = 15000      # StepLR step_size (NOTE(review): unit depends on how train_epoch steps the scheduler — confirm)
GAMMA = 0.1            # StepLR multiplicative decay factor
pre_trained = False    # when True, model/optimizer state is restored from a checkpoint before training

In [None]:
# Build the model with 7 additional (non-image) input features and move it to
# the selected device.
model = GazeCNN(additional_features_size=7)
# Alternative model and a quick forward-pass smoke test, kept commented out:
#model = CNNTrans()
#tensor2 = torch.randn(64,3, 64, 128)
#tensor1 = torch.randn(64,7)
#out = model(tensor2,tensor1)
model.to(device)

##### Criterion and Optimizer

In [None]:
# Metric helper passed to validate().
bbox_accuracy_class = BBoxAccuracy()

# L1 loss as the training/validation criterion.
criterion = nn.L1Loss()

# Adam over all model parameters, configured from the hyperparameter cell.
adam_kwargs = dict(lr=LR, betas=BETAS, weight_decay=WEIGHT_DECAY)
optimizer = torch.optim.Adam(model.parameters(), **adam_kwargs)

# Step decay: multiply the learning rate by GAMMA every STEP_SIZE scheduler steps.
scheduler = lr_scheduler.StepLR(
    optimizer,
    step_size=STEP_SIZE,
    gamma=GAMMA,
)

##### Dataloader

In [None]:
# Training batches are reshuffled every epoch so that each batch is not made of
# the same consecutive samples in the same order; the incomplete last batch is
# dropped. Validation keeps a fixed order and every sample.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

### Training 

In [None]:
# Optionally restore model and optimizer state from a previous run.
if pre_trained:
    # Path of the checkpoint to resume from; must be filled in before enabling
    # pre_trained (an empty path cannot be loaded).
    ckpt_path = ''
    # map_location makes a checkpoint saved on a GPU loadable on the current
    # device, including CPU-only machines.
    checkpoint = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

In [None]:
# Authenticate with Weights & Biases and open a run whose name encodes the
# hyperparameters of this experiment.
wandb.login()
run_name = (
    f"threshold={THRESHOLD}, batch_size={BATCH_SIZE}, "
    f"normalization,{percentage}=percent, "
    f"weight_decay={WEIGHT_DECAY}, lr={LR},betas={BETAS}, "
    f"gamma={GAMMA}, step_size={STEP_SIZE}"
)
wandb.init(project="Testing", name=run_name)

In [None]:
# Decide the epoch range. When resuming, the checkpoint stores the 0-based
# index of the last *completed* epoch, so training continues from the next
# index and runs EPOCHS additional epochs.
if pre_trained:
    start_epoch = checkpoint['epoch'] + 1
    EPOCHS = start_epoch + EPOCHS
else:
    start_epoch = 0

for epoch in range(start_epoch, EPOCHS):
    # Training: one pass over train_loader, returns the epoch's training loss.
    train_loss = train_epoch(model, train_loader, criterion, scheduler, optimizer, device, epoch)
    wandb.log({"epoch": epoch + 1,"train_loss": train_loss})

    # Validation: loss plus the three accuracy/error metrics returned by validate().
    val_loss, val_accuracy, bbox_accuracy, paper_accuracy = validate(model, bbox_accuracy_class , val_loader, THRESHOLD, criterion, device, epoch, BATCH_SIZE)
    wandb.log({"epoch": epoch + 1,"val_loss": val_loss})
    # The two accuracies are logged as percentages.
    wandb.log({"epoch": epoch + 1,"accuracy_threshold": val_accuracy*100})
    wandb.log({"epoch": epoch + 1,"accuracy_bbox": bbox_accuracy*100})
    wandb.log({"epoch": epoch + 1,"accuracy_paper(error)": paper_accuracy})

    #log_image(val_loader, model, device)

# Finish the WandB run
wandb.finish()

In [None]:
# Persist the final training state: index of the last epoch run plus the model
# and optimizer state dicts. The file name encodes EPOCHS and THRESHOLD.
checkpoint_file = f"{root_project}save/baseline_epochs{EPOCHS}_{THRESHOLD}.pth"
save_dict = dict(
    epoch=epoch,
    model_state_dict=model.state_dict(),
    optimizer_state_dict=optimizer.state_dict(),
)
torch.save(save_dict, checkpoint_file)