### Requirements

In [None]:
from dataset import*
from utility import*
from training import*
from baseline import*
from transformer import*
from torch.optim import lr_scheduler 

# Head pose DL model from https://github.com/thohemp/6drepnet
from sixdrepnet import SixDRepNet
import dlib

# Import models
from torchvision import models
from vit_pytorch.twins_svt import TwinsSVT # MEMORIA NON SUFFICIENTE RIPROVARLO
#from vit_pytorch.vit import ViT
from vit_pytorch.ats_vit import ViT
from vit_pytorch import SimpleViT
from vit_pytorch.crossformer import CrossFormer # MEMORIA NON SUFFICIENTE RIPROVARLO
from vit_pytorch.cross_vit import CrossViT

In [None]:
# Absolute path of the project root. It keeps a trailing slash because the
# paths below are built by plain string concatenation onto it.
root_project = "/home/anto/University/Driving-Visual-Attention/"

In [None]:
# Report whether a CUDA GPU is visible and select the device used for training.
_cuda_available = torch.cuda.is_available()
# The previous message rendered as "We have  access" / "We have not access".
print(f"We {'have' if _cuda_available else 'do not have'} access to a GPU")
device = torch.device("cuda" if _cuda_available else "cpu")
if _cuda_available:
    # Index of the currently selected GPU, number of visible GPUs, and the name
    # of GPU 0. (torch.cuda.device(0) is a context manager; printing it only
    # showed an object repr, so it was dropped.)
    print(torch.cuda.current_device())
    print(torch.cuda.device_count())
    print(torch.cuda.get_device_name(0))
print(device)

In [None]:
# Seed every random number generator via the project helper for reproducible runs.
RANDOM_SEED = 42
seed_everything(RANDOM_SEED)

##### Initialize pre-trained models for feature extraction

In [None]:
# Initialize the dlib frontal face detector, the 68-landmark shape predictor,
# and the 6DRepNet head-pose estimator used for feature extraction.
# The landmarks model is resolved from `root_project` instead of a second
# hard-coded absolute path, so relocating the checkout needs one edit only.
predictor = dlib.shape_predictor(root_project + "data/shape_predictor_68_face_landmarks.dat")
headpose_extractor = SixDRepNet()
face_detector = dlib.get_frontal_face_detector()

### Data Loader and Visualization

##### Files where the paths and labels are written

In [None]:
# Dataset percentage: passed to the loaders and embedded in every save-file name.
percentage = 100
# One file per split (train / val / test) under <root_project>/save/.
save_train_file, save_val_file, save_test_file = (
    f"{root_project}save/save_{split}{percentage}" for split in ("train", "val", "test")
)

##### Train Loader

In [None]:
# Build the training-split loader/visualizer, wiring in the face detector,
# landmarks predictor and head-pose extractor created above.
train_dataset_classloader = DataLoaderVisualizer(
    root_project,
    save_train_file,
    percentage,
    predictor,
    face_detector,
    headpose_extractor,
    'train',
)

##### Validation Loader

In [None]:
# Build the validation-split loader/visualizer with the same feature extractors.
val_dataset_classloader = DataLoaderVisualizer(
    root_project,
    save_val_file,
    percentage,
    predictor,
    face_detector,
    headpose_extractor,
    'val',
)

##### Test Loader

In [None]:
#test_dataset_classloader = DataLoaderVisualizer(root_project,save_test_file,percentage,predictor, face_detector, headpose_extractor,'test')

##### Visualization

In [None]:
#train_dataset_classloader.visualize_dataset()

In [None]:
#val_dataset_classloader.visualize_dataset()

In [None]:
#test_dataset_classloader.visualize_dataset()

### PyTorch Dataset

In [None]:
# Resize target passed to transforms.Resize, as (height, width).
# The original comment describes this as the size of the eyes.
dim = (64, 128)
# Per-channel mean and std of the images, calculated in advance.
mean = (0.4570, 0.4422, 0.3900)
std = (0.2376, 0.2295, 0.2261)

my_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(dim, antialias=True),
    # Standardize with the precomputed statistics. Previously `std=mean` was
    # passed by mistake, so the computed `std` was never used.
    transforms.Normalize(mean=mean, std=std, inplace=True),
])

In [None]:
# NOTE: this path is relative, so it resolves against the notebook's working directory.
train_split_file = f'save/save_train{percentage}'
train_dataset = DGAZEDataset('train', train_split_file, my_transforms)
print(f'Train dataset len is {len(train_dataset)}')

In [None]:
# Show one dataset sample to sanity-check the data pipeline.
# Fetch the sample once: each `train_dataset[i]` calls __getitem__ again
# (presumably re-running the transforms), so indexing twice did the work twice.
sample = train_dataset[30]
img_np = sample[0].permute(1, 2, 0).numpy()  # CHW tensor -> HWC array for matplotlib
print(f"The bbox is: {sample[3]}")
# NOTE(review): the image is presumably normalized by my_transforms, so colours
# look distorted and matplotlib clips float RGB values outside [0, 1];
# undo the normalization first for a faithful view.
plt.imshow(img_np)
plt.axis('off')
plt.show()

In [None]:
# NOTE: this path is relative, so it resolves against the notebook's working directory.
val_split_file = f'save/save_val{percentage}'
val_dataset = DGAZEDataset('val', val_split_file, my_transforms)
print(f'Val dataset len is {len(val_dataset)}')

In [None]:
#test_dataset = DGAZEDataset('test','save/save_test'+str(percentage),my_transforms)
#print(f'Test dataset len is {len(test_dataset)}')

### Vision Transformer Model

##### Hyperparameters

In [None]:
# Training hyperparameters.
EPOCHS, BATCH_SIZE = 15, 32
THRESHOLD = 250      # passed to validate(); drives the "accuracy_threshold" metric
pre_trained = False  # True -> load a checkpoint and resume before training

In [None]:
# Baseline CNN. The alternative model, CNNTrans(), can be swapped in here.
model = GazeCNN()
model.to(device)

##### Criterion and Optimizer

In [None]:
# L1 (mean absolute error) training loss.
criterion = nn.L1Loss()
# Helper for the bounding-box accuracy reported during validation.
bbox_accuracy_class = BBoxAccuracy()
# Adam with beta2 = 0.95. StepLR multiplies the learning rate by 0.1 every
# 20000 calls to scheduler.step(); how often that happens is up to train_epoch.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=0.01,
    betas=(0.9, 0.95),
)
scheduler = lr_scheduler.StepLR(optimizer, step_size=20000, gamma=0.1)

##### Dataloader

In [None]:
# Training batches are shuffled, and the final incomplete batch is dropped.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
)
# Validation keeps a fixed order and every sample, including a partial last batch.
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

### Training 

In [None]:
# Optionally resume from a checkpoint dict holding 'model_state_dict' and
# 'optimizer_state_dict' (plus, if present, 'scheduler_state_dict').
if pre_trained:
    ckpt_path = ''  # set to the .pth checkpoint to resume from
    if not ckpt_path:
        raise ValueError("pre_trained is True but ckpt_path is empty; set it to a checkpoint file")
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine
    # (and vice versa) by remapping its tensors onto the current device.
    checkpoint = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # Older checkpoints lack scheduler state; restore it only when available.
    if 'scheduler_state_dict' in checkpoint:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

In [None]:
# Authenticate with Weights & Biases, then open a run labelled with the key settings.
wandb.login()
run_name = f"threshold={THRESHOLD} RGB batch_size{BATCH_SIZE} images with normalization {percentage} percent"
wandb.init(project="Baseline CNN", name=run_name)

In [None]:
# Train for EPOCHS epochs, validating after each one and logging to W&B.
# When resuming, start at the epoch stored in the checkpoint; EPOCHS then
# becomes the absolute epoch index at which training stops.
if pre_trained:
    start_epoch = checkpoint['epoch']
    EPOCHS = start_epoch + EPOCHS
else:
    start_epoch = 0

try:
    for epoch in range(start_epoch, EPOCHS):
        # Training
        train_loss = train_epoch(model, train_loader, criterion, scheduler, optimizer, device, epoch)

        # Validation
        val_loss, val_accuracy, bbox_accuracy, paper_accuracy = validate(
            model, bbox_accuracy_class, val_loader, THRESHOLD, criterion, device, epoch, BATCH_SIZE
        )

        # Log all of this epoch's metrics in one call. Every wandb.log call
        # advances the W&B step, so the former five separate calls spread a
        # single epoch over five steps.
        wandb.log({
            "epoch": epoch + 1,
            "train_loss": train_loss,
            "val_loss": val_loss,
            "accuracy_threshold": val_accuracy * 100,
            "accuracy_bbox": bbox_accuracy * 100,
            "accuracy_paper": paper_accuracy,
        })

        #log_image(val_loader, model, device)
finally:
    # Finish the W&B run even if training is interrupted or raises.
    wandb.finish()

In [None]:
# Checkpoint the training state so a later run can resume with pre_trained=True.
# 'epoch' stores the number of completed epochs (last index + 1). The resume
# cell sets start_epoch = checkpoint['epoch'], so storing the last index made
# the final epoch run twice.
save_dict = {
    'epoch': epoch + 1,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    # Saved so the StepLR decay schedule can also be restored on resume.
    'scheduler_state_dict': scheduler.state_dict(),
}
torch.save(save_dict, root_project + 'save/baseline_epochs' + str(EPOCHS) + '.pth')

### Test