In [1]:
import sys
sys.path.append('..')
import torch
from torchmetrics.classification import BinaryF1Score, F1Score
import wandb
from tqdm import trange
import os
from datetime import datetime
from sklearn import metrics
from matplotlib import pyplot as plt
import torch
import torch.nn as nn
import pytorch_lightning as pl
import numpy as np

sys.path.insert(0, '/Users/evanpan/Documents/GitHub/EvansToolBox/Utils')
sys.path.insert(0, '/Users/evanpan/Desktop/openpose/python/')
sys.path.insert(0, '/scratch/ondemand27/evanpan/EvansToolBox/Utils/')
sys.path.insert(0, '/scratch/ondemand27/evanpan/Gaze_project/')

# from training.model import *
from Dataset_Util.dataloader import *

In [2]:
%load_ext autoreload
%autoreload 1

%aimport training.model
%aimport Dataset_Util.dataloader

# Loaded Things

In [3]:
dataset_location = "/scratch/ondemand27/evanpan/data/deep_learning_processed_dataset"
model_save_location = "/scratch/ondemand27/evanpan/data/Gaze_aversion_models"
config = json.load(open("/scratch/ondemand27/evanpan/Gaze_project/training/baseline_config.json", "r"))

## Load dataset

In [4]:
# do the training test split here:
dataset_metadata = "/scratch/ondemand27/evanpan/data/deep_learning_processed_dataset/video_to_window_metadata.json"
dataset_metadata = json.load(open(dataset_metadata, "r"))
all_videos = list(dataset_metadata.keys())
training_set = []
testing_set = []
# get the name of the videos (this ensures no contamination because the same shot is split)
for i in range(0, len(all_videos)):
    if i / len(all_videos) < 0.9:
        training_set.append(all_videos[i])
    else:
        testing_set.append(all_videos[i])

In [5]:
class SentenceBaseline_GazePredictionModel(nn.Module):
    def __init__(self, config):
        torch.set_default_tensor_type(torch.DoubleTensor)
        # initialize model
        super(SentenceBaseline_GazePredictionModel, self).__init__()
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.activation = nn.Sigmoid()
        self.num_layers = config["num_layers"]
        self.config = config
        # the feature of each speaker are encoded with a separate Linear Layer
        self.input_layer_self = nn.Linear(int(config["input_dims"]/2 - 6), config["input_layer_out"])
        self.input_layer_other = nn.Linear(int(config["input_dims"]/2 - 6), config["input_layer_out"])
        
        # the Recurrent Layer will take care of the next step
        self.lstm_hidden_dims = config["lstm_output_feature_size"]
        self.num_lstm_layer = config["lstm_layer_num"]
        self.frames_ahead = config["frames_ahead"]
        self.frames_behind = config["frames_behind"]
        self.lstm = nn.LSTM(2 * (config["input_layer_out"] + 6) * (self.frames_ahead + self.frames_behind + 1), 
                            self.lstm_hidden_dims, 
                            self.num_lstm_layer, 
                            batch_first=True)        
        # output layers
        self.output_layer_1 = nn.Linear(self.lstm_hidden_dims, config["output_layer_1_hidden"])
        self.output_layer_1 = nn.Sequential(self.output_layer_1, self.activation, nn.Dropout(self.config["dropout"]))
        self.output_layer_2 = nn.Linear(config["output_layer_1_hidden"], config["output_layer_2_hidden"])
        self.output_layer_2 = nn.Sequential(self.output_layer_2, self.activation, nn.Dropout(self.config["dropout"]))
        self.output_layer_3 = nn.Linear(config["output_layer_2_hidden"], config["output_layer_3_hidden"])
        self.output_layer_3 = nn.Sequential(self.output_layer_3)

        # audio_filler = torch.tensor([[[-36.04365338911715,0.0,0.0,0.0,0.0,0.0,-3.432169450445466e-14,0.0,0.0,0.0,9.64028691651994e-15,0.0,0.0,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715]]]).to(self.device)
        # text_filler = torch.ones([1, 1, 772]).to(self.device) * -15
        # text_filler[:, :, -4:] = 0
        # self.filler = torch.concat([audio_filler, text_filler], axis=2)
    def concate_frames(self, input_feature):
        # here I expect the 
        padding_front = torch.zeros([input_feature.shape[0], self.frames_ahead, input_feature.shape[2]]).to(self.device)
        padding_back = torch.zeros([input_feature.shape[0], self.frames_behind, input_feature.shape[2]]).to(self.device)
        padded_input_audio = torch.cat([padding_front, input_feature, padding_back], dim=1)
        window_audio = []
        for i in range(0, input_feature.shape[1]):
            window_count = i + self.frames_ahead
            current_window = padded_input_audio[:, window_count-self.frames_ahead:window_count+self.frames_behind+1]
            s = current_window.shape
            current_window = current_window.view((s[0], s[1] * s[2]))
            current_window = torch.unsqueeze(current_window, 1)
            window_audio.append(current_window)
        rtv = torch.cat(window_audio, dim=1)
        return rtv
    def forward(self, input_feature):
        feature_size = int(input_feature.size()[2] / 2)
        mod_audio_self = input_feature[:, :, :feature_size]
        mod_audio_other = input_feature[:, :, feature_size:]
        
        text_feature_self = mod_audio_self[:, :, :6]
        mod_audio_self = mod_audio_self[:, :, 6:]
        text_feature_other = mod_audio_self[:, :, :6]
        mod_audio_other = mod_audio_other[:, :, 6:]
        x1 = self.activation(self.input_layer_self(mod_audio_self))
        x2 = self.activation(self.input_layer_self(mod_audio_other))
        x1_windowed = self.concate_frames(x1)
        x2_windowed = self.concate_frames(x2)
        x_combined = torch.concat([x1_windowed, text_feature_self, x2_windowed, text_feature_other], axis=2)
        # here I'm assuming that the input_audio is of proper shape
        out, hidden_state = self.lstm(x_combined)
        # bn
        # x = self.bn(out.permute(0, 2, 1)).permute(0, 2, 1)
        x = self.activation(out)
        x = self.output_layer_1(x)
        x = self.output_layer_2(x)
        x = self.output_layer_3(x)
        return x
    def load_weights(self, pretrained_dict):
    #   not_copy = set(['fc.weight', 'fc.bias'])
        model_dict = self.state_dict()
        pretrained_dict = {k: v for k, v in pretrained_dict.items()}
        model_dict.update(pretrained_dict)
        self.load_state_dict(model_dict)

In [6]:
class SimpleBaseline_GazePredictionModel(nn.Module):
    def __init__(self, config):
        torch.set_default_tensor_type(torch.DoubleTensor)
        # initialize model
        super(SimpleBaseline_GazePredictionModel, self).__init__()
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.activation = nn.Sigmoid()
        self.num_layers = config["num_layers"]
        self.config = config
        # the feature of each speaker are encoded with a separate Linear Layer
        self.input_layer_self = nn.Linear(int(config["input_dims"]/2), config["input_layer_out"])
        self.input_layer_other = nn.Linear(int(config["input_dims"]/2), config["input_layer_out"])
        
        # the Recurrent Layer will take care of the next step
        self.lstm_hidden_dims = config["lstm_output_feature_size"]
        self.num_lstm_layer = config["lstm_layer_num"]
        self.frames_ahead = config["frames_ahead"]
        self.frames_behind = config["frames_behind"]
        self.lstm = nn.LSTM(2 * config["input_layer_out"] * (self.frames_ahead + self.frames_behind + 1), 
                            self.lstm_hidden_dims, 
                            self.num_lstm_layer, 
                            batch_first=True)
        self.lstm = nn.LSTM(2 * config["input_layer_out"] * (self.frames_ahead + self.frames_behind + 1), 
                            self.lstm_hidden_dims, 
                            self.num_lstm_layer, 
                            batch_first=True)
        
        # output layers
        self.output_layer_1 = nn.Linear(self.lstm_hidden_dims, config["output_layer_1_hidden"])
        self.output_layer_1 = nn.Sequential(self.output_layer_1, self.activation, nn.Dropout(self.config["dropout"]))
        self.output_layer_2 = nn.Linear(config["output_layer_1_hidden"], config["output_layer_2_hidden"])
        self.output_layer_2 = nn.Sequential(self.output_layer_2, self.activation, nn.Dropout(self.config["dropout"]))
        self.output_layer_3 = nn.Linear(config["output_layer_2_hidden"], config["output_layer_3_hidden"])
        self.output_layer_3 = nn.Sequential(self.output_layer_3, nn.Sigmoid())

        # audio_filler = torch.tensor([[[-36.04365338911715,0.0,0.0,0.0,0.0,0.0,-3.432169450445466e-14,0.0,0.0,0.0,9.64028691651994e-15,0.0,0.0,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715]]]).to(self.device)
        # text_filler = torch.ones([1, 1, 772]).to(self.device) * -15
        # text_filler[:, :, -4:] = 0
        # self.filler = torch.concat([audio_filler, text_filler], axis=2)
    def concate_frames(self, input_feature):
        # here I expect the 
        padding_front = torch.zeros([input_feature.shape[0], self.frames_ahead, input_feature.shape[2]]).to(self.device)
        padding_back = torch.zeros([input_feature.shape[0], self.frames_behind, input_feature.shape[2]]).to(self.device)
        padded_input_audio = torch.cat([padding_front, input_feature, padding_back], dim=1)
        window_audio = []
        for i in range(0, input_feature.shape[1]):
            window_count = i + self.frames_ahead
            current_window = padded_input_audio[:, window_count-self.frames_ahead:window_count+self.frames_behind+1]
            s = current_window.shape
            current_window = current_window.view((s[0], s[1] * s[2]))
            current_window = torch.unsqueeze(current_window, 1)
            window_audio.append(current_window)
        rtv = torch.cat(window_audio, dim=1)
        return rtv
    def forward(self, input_feature):
        feature_size = int(input_feature.size()[2] / 2)
        mod_audio_self = input_feature[:, :, :feature_size]
        mod_audio_other = input_feature[:, :, feature_size:]
        x1 = self.activation(self.input_layer_self(mod_audio_self))
        x2 = self.activation(self.input_layer_self(mod_audio_other))
        x1_windowed = self.concate_frames(x1)
        x2_windowed = self.concate_frames(x2)
        x_combined = torch.concat([x1_windowed, x2_windowed], axis=2)
        # here I'm assuming that the input_audio is of proper shape
        out, hidden_state = self.lstm(x_combined)
        # bn
        # x = self.bn(out.permute(0, 2, 1)).permute(0, 2, 1)
        x = self.activation(out)
        x = self.output_layer_1(x)
        x = self.output_layer_2(x)
        x = self.output_layer_3(x)
        return x

In [7]:
class SimpleBaselineTransformer_GazePredictionModel(nn.Module):
    def __init__(self, config):
        torch.set_default_tensor_type(torch.DoubleTensor)
        # initialize model
        super(SimpleBaselineTransformer_GazePredictionModel, self).__init__()
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.activation = nn.Sigmoid()
        self.num_layers = config["num_layers"]
        self.config = config
        # the feature of each speaker are encoded with a separate Linear Layer
        self.input_layer_self = nn.Linear(int(config["input_dims"]/2), config["input_layer_out"])
        self.input_layer_other = nn.Linear(int(config["input_dims"]/2), config["input_layer_out"])
        
        # the Recurrent Layer will take care of the next step
        self.lstm_hidden_dims = config["lstm_output_feature_size"]
        self.num_lstm_layer = config["lstm_layer_num"]
        self.frames_ahead = config["frames_ahead"]
        self.frames_behind = config["frames_behind"]
        self.lstm = nn.LSTM(2 * config["input_layer_out"] * (self.frames_ahead + self.frames_behind + 1), 
                            self.lstm_hidden_dims, 
                            self.num_lstm_layer, 
                            batch_first=True)
        self.trasnformer = nn.Transformer(2 * config["input_layer_out"] * (self.frames_ahead + self.frames_behind + 1), batch_first=True,)
        
        # output layers
        self.output_layer_1 = nn.Linear(2 * config["input_layer_out"] * (self.frames_ahead + self.frames_behind + 1), config["output_layer_1_hidden"])
        self.output_layer_1 = nn.Sequential(self.output_layer_1, self.activation, nn.Dropout(self.config["dropout"]))
        self.output_layer_2 = nn.Linear(config["output_layer_1_hidden"], config["output_layer_2_hidden"])
        self.output_layer_2 = nn.Sequential(self.output_layer_2, self.activation, nn.Dropout(self.config["dropout"]))
        self.output_layer_3 = nn.Linear(config["output_layer_2_hidden"], config["output_layer_3_hidden"])
        self.output_layer_3 = nn.Sequential(self.output_layer_3, nn.Sigmoid())

        # audio_filler = torch.tensor([[[-36.04365338911715,0.0,0.0,0.0,0.0,0.0,-3.432169450445466e-14,0.0,0.0,0.0,9.64028691651994e-15,0.0,0.0,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715,-36.04365338911715]]]).to(self.device)
        # text_filler = torch.ones([1, 1, 772]).to(self.device) * -15
        # text_filler[:, :, -4:] = 0
        # self.filler = torch.concat([audio_filler, text_filler], axis=2)
    def concate_frames(self, input_feature):
        # here I expect the 
        padding_front = torch.zeros([input_feature.shape[0], self.frames_ahead, input_feature.shape[2]]).to(self.device)
        padding_back = torch.zeros([input_feature.shape[0], self.frames_behind, input_feature.shape[2]]).to(self.device)
        padded_input_audio = torch.cat([padding_front, input_feature, padding_back], dim=1)
        window_audio = []
        for i in range(0, input_feature.shape[1]):
            window_count = i + self.frames_ahead
            current_window = padded_input_audio[:, window_count-self.frames_ahead:window_count+self.frames_behind+1]
            s = current_window.shape
            current_window = current_window.view((s[0], s[1] * s[2]))
            current_window = torch.unsqueeze(current_window, 1)
            window_audio.append(current_window)
        rtv = torch.cat(window_audio, dim=1)
        return rtv
    def forward(self, input_feature):
        feature_size = int(input_feature.size()[2] / 2)
        mod_audio_self = input_feature[:, :, :feature_size]
        mod_audio_other = input_feature[:, :, feature_size:]
        x1 = self.activation(self.input_layer_self(mod_audio_self))
        x2 = self.activation(self.input_layer_self(mod_audio_other))
        x1_windowed = self.concate_frames(x1)
        x2_windowed = self.concate_frames(x2)
        x_combined = torch.concat([x1_windowed, x2_windowed], axis=2)
        # here I'm assuming that the input_audio is of proper shape
        out, hidden_state = self.lstm(x_combined)
        # bn
        # x = self.bn(out.permute(0, 2, 1)).permute(0, 2, 1)
        x = self.activation(out)
        x = self.output_layer_1(x)
        x = self.output_layer_2(x)
        x = self.output_layer_3(x)
        return x

# Training Loop - For baseline model

In [10]:
def train_model(model, config, train_data, valid_data, wandb, model_name):
    optimiser = torch.optim.Adam(model.parameters(), lr=config["learning_rate"])
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    model.train() 
    loss_fn = nn.CrossEntropyLoss()
    training_loss = []
    valid_loss = []
    training_f1 = []
    valid_f1 = []
    aversion_vs_start = []
    count = 0
    # f1_score = BinaryF1Score(num_classes=2).to(device)
    f1_score = F1Score(task="multiclass", num_classes=2, average="weighted").to(device)
    for epoch in range(1, config['epochs'] + 1):
        total_train_loss = 0
        total_valid_loss = 0
        total_aversion_predicted = 0
        total_train_f1 = 0
        total_valid_f1 = 0
        train_batch_counter = 0
        valid_batch_counter = 0
        total_prediction_counter = 0
        prediction_mean = 0
        prediction_std = 0
        model.zero_grad()
        for _, (X, Y) in enumerate(train_data):
            train_batch_counter += 1
            X, Y = X.to(device), Y.to(device)
            optimiser.zero_grad()
            if "Transformer" in config["model_type"]:
                all_zero = torch.zeros(Y.shape).to(device)
                pred = model(X, all_zero)
            else:
                pred = model(X)
            loss = loss_fn(pred.transpose(2, 1), Y.long())
            loss.backward()
            optimiser.step()
            total_train_loss += loss.item()
            # binary_pred = torch.round(pred)
            binary_pred = torch.argmax(pred, axis=2, keepdim=True)
            prediction_mean = torch.mean(binary_pred.float()).item()
            prediction_std = torch.std(binary_pred.float()).item()            
            f1_train = f1_score(binary_pred, torch.unsqueeze(Y, axis=2)).item()
            total_aversion_predicted += torch.sum(binary_pred).item()
            total_prediction_counter += binary_pred.size()[0] * binary_pred.size()[1] 
            total_train_f1 += f1_train
            del X, Y, pred
            torch.cuda.empty_cache()

        total_train_f1 /= train_batch_counter
        total_train_loss /= len(train_data)
        total_aversion_predicted /= total_prediction_counter

        for _, (X, Y) in enumerate(valid_data):
            with torch.no_grad():
                valid_batch_counter += 1
                X, Y = X.to(device), Y.to(device)
                if "Transformer" in config["model_type"]:
                    all_zero = torch.zeros(Y.shape).to(device)
                    pred = model(X, all_zero)
                else:
                    pred = model(X)
                loss = loss_fn(pred.transpose(2, 1), Y.long())
                total_valid_loss += loss.item()

                # binary_pred = torch.round(pred)
                binary_pred = torch.argmax(pred, axis=2, keepdim=True)
                f1_valid = f1_score(binary_pred, torch.unsqueeze(Y, axis=2)).item()
                total_valid_f1 += f1_valid
                del X, Y, pred
                torch.cuda.empty_cache()

        total_valid_f1 /= valid_batch_counter
        total_valid_loss /= len(valid_data)

        if config['wandb']:
            wandb.log({'training loss': total_train_loss,
                        'validation_loss': total_valid_loss,
                        'training_f1': total_train_f1,
                        'validation_f1': total_valid_f1, 
                        "percentage_predicted_aversion": total_aversion_predicted})
        training_loss.append(total_train_loss)
        valid_loss.append(total_valid_loss)
        training_f1.append(total_train_f1)
        valid_f1.append(total_valid_f1)
        aversion_vs_start.append(total_aversion_predicted)
        if total_valid_f1 == max(valid_f1):
            try:
                os.mkdir(os.path.join(*[model_save_location, model_name]))
            except:
                pass
            config_save_path = os.path.join(*[model_save_location, model_name, "config.json"])
            json.dump(config, open(config_save_path, "w"))
            file_name = f'time={datetime.now()}_epoch={epoch}.pt'
            save_path = os.path.join(*[model_save_location, model_name, file_name])
            torch.save(model.state_dict(), save_path)
        if config['early_stopping']>0:
            if epoch > 1:
                if total_valid_f1 < np.mean(valid_f1[epoch - 7:epoch - 2]):
                    count += 1
                else:
                    count = 0
            if count >= config['early_stopping']:
                print('\n\nStopping early due to decrease in performance on validation set\n\n')
                break 
        if count == 0:
            print("Epoch {}, mean: {}, std: {}\ntraining L: {}\nvalidation L:{}".format(epoch, prediction_mean, prediction_std, total_train_f1, total_valid_f1))
        else:
            print("Epoch {}, mean: {}, std: {}\ntraining L: {}\nvalidation L:{}, model have not improved for {} iterations".format(epoch, prediction_mean, prediction_std, total_train_f1, total_valid_f1, count))
    if config['wandb']:
        save_path = os.path.join(*[model_save_location, model_name, file_name])
        wandb.save(save_path)

## Weight and Biases Stuff

## Train the model

In [57]:
config = json.load(open("/scratch/ondemand27/evanpan/Gaze_project/training/baseline_config.json", "r"))
# obtain the dataset
torch.set_default_tensor_type(torch.DoubleTensor)
training_dataset = Aversion_SelfTap111(dataset_location, training_set[:10])
validation_dataset = Aversion_SelfTap111(dataset_location, testing_set[:2])
train_dataloader = torch.utils.data.DataLoader(training_dataset, config['batch_size'], True)
valid_dataloader = torch.utils.data.DataLoader(validation_dataset, config['batch_size'], True)
model = SimpleBaseline_GazePredictionModel(config)
train_model(model, config, train_dataloader, valid_dataloader, run_obj, model_save_location)
run_obj.finish()

  0%|          | 0/200 [00:00<?, ?it/s]

0.03449014124949654


  0%|          | 1/200 [00:02<07:47,  2.35s/it]

mean: 0.5974612089529724, std: 0.03449014124949654
training L: 0.7498041014467824
validation L:0.7801158889905458
0.0387789817215484


  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  1%|          | 2/200 [00:04<06:59,  2.12s/it]

mean: 0.648003122124935, std: 0.0387789817215484
training L: 0.750510350090623
validation L:0.7802088892277198
0.029644162000120988


  2%|▏         | 3/200 [00:06<06:41,  2.04s/it]

mean: 0.6950838980002896, std: 0.029644162000120988
training L: 0.7506199214130393
validation L:0.7802088892277198
0.02540355121551124


  2%|▏         | 4/200 [00:08<06:31,  2.00s/it]

mean: 0.6974240211738563, std: 0.02540355121551124
training L: 0.7506199214130393
validation L:0.7802088892277198
0.014475230401421608


  2%|▎         | 5/200 [00:10<06:36,  2.03s/it]

mean: 0.6587029522108311, std: 0.014475230401421608
training L: 0.7506199214130393
validation L:0.7802088892277198
0.008366695740317093


  3%|▎         | 6/200 [00:12<06:28,  2.00s/it]

mean: 0.6412970440137612, std: 0.008366695740317093
training L: 0.7506199214130393
validation L:0.7802088892277198


  3%|▎         | 6/200 [00:12<06:37,  2.05s/it]


KeyboardInterrupt: 

# Training Loop - For Audio Only model

In [None]:
run_obj = wandb.init(project="gaze_prediction", config=config, settings=wandb.Settings(start_method="fork"))

In [20]:
config = json.load(open("/scratch/ondemand27/evanpan/Gaze_project/training/audio_only_config.json", "r"))
# obtain the dataset
torch.set_default_tensor_type(torch.DoubleTensor)
training_dataset = Aversion_SelfTap111(dataset_location, training_set[:12], audio_only=True)
validation_dataset = Aversion_SelfTap111(dataset_location, testing_set[:2], audio_only=True)
train_dataloader = torch.utils.data.DataLoader(training_dataset, config['batch_size'], True)
valid_dataloader = torch.utils.data.DataLoader(validation_dataset, config['batch_size'], True)
model = SimpleBaseline_GazePredictionModel(config)
train_model(model, config, train_dataloader, valid_dataloader, run_obj, model_save_location)

  0%|          | 1/200 [00:01<05:35,  1.69s/it]

mean: 0.5375652030339769, std: 0.038422848050369815
training L: 0.7323386721605636
validation L:0.7280215162212137


  1%|          | 2/200 [00:03<05:08,  1.56s/it]

mean: 0.5396894879007264, std: 0.03841229046891618
training L: 0.7376692695463082
validation L:0.7348674337168585


  2%|▏         | 3/200 [00:04<04:59,  1.52s/it]

mean: 0.5418414783457995, std: 0.0382727172988321
training L: 0.7430556688251317
validation L:0.7396762471093492


  2%|▏         | 4/200 [00:05<04:43,  1.45s/it]

mean: 0.5443468853526254, std: 0.03826918974963933
training L: 0.7486199438292928
validation L:0.7408316961362148


  2%|▎         | 5/200 [00:07<04:44,  1.46s/it]

mean: 0.5468758275736018, std: 0.03847715916641862
training L: 0.751742524009893
validation L:0.7457544486877387


  3%|▎         | 6/200 [00:08<04:34,  1.42s/it]

mean: 0.5487699903925971, std: 0.03846862209941646
training L: 0.7545170146204431
validation L:0.7470093760103459


  4%|▎         | 7/200 [00:10<04:38,  1.44s/it]

mean: 0.5512063178953902, std: 0.03835452302643252
training L: 0.7570841856556142
validation L:0.7537631811961684


  4%|▍         | 8/200 [00:11<04:39,  1.45s/it]

mean: 0.5532094666883071, std: 0.038496512912125684
training L: 0.7617156097791649
validation L:0.7549647661755285


  4%|▍         | 9/200 [00:13<04:39,  1.47s/it]

mean: 0.5554321769720268, std: 0.03805370844533758
training L: 0.7656895903795565
validation L:0.7606082318286761


  5%|▌         | 10/200 [00:14<04:38,  1.47s/it]

mean: 0.5574337472479528, std: 0.03852652032566119
training L: 0.7645938095386826
validation L:0.7613699280234122


  6%|▌         | 11/200 [00:16<04:29,  1.42s/it]

mean: 0.559981626833095, std: 0.038347647409568156
training L: 0.7698066319794983
validation L:0.7640094711917916


  6%|▌         | 12/200 [00:17<04:30,  1.44s/it]

mean: 0.5622004870390143, std: 0.038257843264804414
training L: 0.7722952821004156
validation L:0.7646226415094339


  6%|▋         | 13/200 [00:18<04:22,  1.40s/it]

mean: 0.5640483085938395, std: 0.038382897538683555
training L: 0.7718415896573535
validation L:0.7659007136695161


  7%|▋         | 14/200 [00:20<04:16,  1.38s/it]

mean: 0.5665515094659646, std: 0.03814504706038931
training L: 0.7751711855559309
validation L:0.7667606516290727


  8%|▊         | 15/200 [00:21<04:12,  1.36s/it]

mean: 0.5685410694377985, std: 0.03836502495453606
training L: 0.7761418131118908
validation L:0.767887763055339


  8%|▊         | 16/200 [00:22<04:17,  1.40s/it]

mean: 0.5710678310108791, std: 0.038261168945831855
training L: 0.7776878065176539
validation L:0.7699121920895174


  8%|▊         | 17/200 [00:24<04:20,  1.43s/it]

mean: 0.5729987377715186, std: 0.038263397836928795
training L: 0.7794270039622067
validation L:0.7713554193498332


  9%|▉         | 18/200 [00:25<04:22,  1.44s/it]

mean: 0.575360365072677, std: 0.03808680008983807
training L: 0.7809720723049091
validation L:0.7731926776295377


 10%|▉         | 19/200 [00:27<04:15,  1.41s/it]

mean: 0.5771138850246609, std: 0.038405858208403906
training L: 0.7801478737455783
validation L:0.7712205700123915


 10%|█         | 20/200 [00:28<04:18,  1.44s/it]

mean: 0.5790664975047349, std: 0.03818548580533589
training L: 0.7813007020576506
validation L:0.7748400524165575


 10%|█         | 21/200 [00:30<04:19,  1.45s/it]

mean: 0.5807451319438747, std: 0.038410857081182706
training L: 0.7829657011271361
validation L:0.775189014041043


 11%|█         | 22/200 [00:31<04:11,  1.41s/it]

mean: 0.582903799551212, std: 0.03838164564712754
training L: 0.782283917014456
validation L:0.7741985203452528


 12%|█▏        | 23/200 [00:32<04:05,  1.39s/it]

mean: 0.5850501231411687, std: 0.03800898681254578
training L: 0.7848788953444503
validation L:0.7764959237040455


 12%|█▏        | 24/200 [00:34<04:08,  1.41s/it]

mean: 0.5871901499704449, std: 0.03825333668878838
training L: 0.7846295374685186
validation L:0.7776071264014744


 12%|█▎        | 25/200 [00:35<04:02,  1.39s/it]

mean: 0.5890066738191944, std: 0.03832396503706255
training L: 0.7850146176798577
validation L:0.7762248502534173


 13%|█▎        | 26/200 [00:37<03:57,  1.37s/it]

mean: 0.5907227826742486, std: 0.03842968105761962
training L: 0.7855368728455089
validation L:0.7766096232062006


 14%|█▎        | 27/200 [00:38<03:54,  1.35s/it]

mean: 0.5928269195039755, std: 0.038216674116663824
training L: 0.7860214730384049
validation L:0.778169014084507


 14%|█▍        | 28/200 [00:39<03:50,  1.34s/it]

mean: 0.5951770582878034, std: 0.0382402461495092
training L: 0.786470234515935
validation L:0.7770622508432996


 14%|█▍        | 29/200 [00:40<03:46,  1.33s/it]

mean: 0.5965046600255317, std: 0.038272955835638604
training L: 0.7863828252933311
validation L:0.7768582375478927


 15%|█▌        | 30/200 [00:42<03:45,  1.33s/it]

mean: 0.5982373753965138, std: 0.038192511542916426
training L: 0.7866228168971783
validation L:0.7777267156862745


 16%|█▌        | 31/200 [00:43<03:52,  1.38s/it]

mean: 0.5998975201858427, std: 0.03825841699566252
training L: 0.7868739309185847
validation L:0.7783386874713171


 16%|█▌        | 32/200 [00:45<03:48,  1.36s/it]

mean: 0.6019111391163703, std: 0.038308990823368284
training L: 0.7869959193470956
validation L:0.7779735008041664


 16%|█▋        | 33/200 [00:46<03:52,  1.39s/it]

mean: 0.6036713508981063, std: 0.03826371443377106
training L: 0.7872043671920694
validation L:0.7781857121003518


 17%|█▋        | 34/200 [00:47<03:48,  1.38s/it]

mean: 0.6051162892788426, std: 0.03837444523661545
training L: 0.7873779564167555
validation L:0.7789152024446142


 18%|█▊        | 35/200 [00:49<03:44,  1.36s/it]

mean: 0.6067495495697391, std: 0.03793514279397345
training L: 0.7877943138048349
validation L:0.7784742394129338


 18%|█▊        | 36/200 [00:50<03:42,  1.36s/it]

mean: 0.6085021555963259, std: 0.03827496985733568
training L: 0.7881376251273752
validation L:0.7788652699189479


 18%|█▊        | 37/200 [00:52<03:48,  1.40s/it]

mean: 0.6098727395617883, std: 0.038379580377216525
training L: 0.7877607288420043
validation L:0.7788138184041578


 19%|█▉        | 38/200 [00:53<03:43,  1.38s/it]

mean: 0.6117281921448884, std: 0.03800839217986361
training L: 0.7881538738090959
validation L:0.7788219115287646


 20%|█▉        | 39/200 [00:54<03:39,  1.36s/it]

mean: 0.6132555046716315, std: 0.03813427039511276
training L: 0.7882763475161373
validation L:0.78


 20%|██        | 40/200 [00:56<03:36,  1.35s/it]

mean: 0.6146646101171072, std: 0.03815154347277368
training L: 0.7881655387195304
validation L:0.7792108677402122


 20%|██        | 41/200 [00:57<03:42,  1.40s/it]

mean: 0.6157974883200633, std: 0.03834291259328531
training L: 0.7884848575577479
validation L:0.7794488128864799


 21%|██        | 42/200 [00:58<03:37,  1.38s/it]

mean: 0.6175480401993695, std: 0.0380655845445438
training L: 0.78839590443686
validation L:0.7795678399633504


 22%|██▏       | 43/200 [01:00<03:33,  1.36s/it]

mean: 0.6189577184437718, std: 0.03818574687252487
training L: 0.7886725239855712
validation L:0.7797955135052648


 22%|██▏       | 44/200 [01:01<03:38,  1.40s/it]

mean: 0.6200836475331568, std: 0.03797256523459504
training L: 0.7887201735357917
validation L:0.7800747577999847


 22%|██▎       | 45/200 [01:03<03:33,  1.38s/it]

mean: 0.6214130102648945, std: 0.03796793105375696
training L: 0.7886042614316495
validation L:0.7794566544566545


 23%|██▎       | 46/200 [01:04<03:30,  1.36s/it]

mean: 0.62247247622866, std: 0.03803493969337708
training L: 0.7886333368249361
validation L:0.7800411993591211


 24%|██▎       | 47/200 [01:05<03:27,  1.36s/it]

mean: 0.6238084493179571, std: 0.037875842012672904
training L: 0.7886422117168332
validation L:0.7798809342085178


 24%|██▍       | 48/200 [01:07<03:26,  1.36s/it]

mean: 0.6248945816924095, std: 0.03827418552441704
training L: 0.7889847875188847
validation L:0.7793971766501335


 24%|██▍       | 49/200 [01:08<03:23,  1.35s/it]

mean: 0.6261484635792381, std: 0.03790382461395846
training L: 0.7887989708456119
validation L:0.7801007172287502


 25%|██▌       | 50/200 [01:09<03:21,  1.34s/it]

mean: 0.6274820380729639, std: 0.03795289443161719
training L: 0.7888533221647172
validation L:0.7798368031724243


 26%|██▌       | 51/200 [01:11<03:19,  1.34s/it]

mean: 0.6287918735569356, std: 0.038010754836282835
training L: 0.7889241263762565
validation L:0.779769623922496


 26%|██▌       | 52/200 [01:12<03:18,  1.34s/it]

mean: 0.6298811093429626, std: 0.037993665419687114
training L: 0.7888297235850186
validation L:0.7800152555301296


 26%|██▋       | 53/200 [01:13<03:16,  1.34s/it]

mean: 0.6308418167201261, std: 0.03790918813462353
training L: 0.7888596386263013
validation L:0.779929845966143


 27%|██▋       | 54/200 [01:15<03:15,  1.34s/it]

mean: 0.6317639670979591, std: 0.03772842298266372
training L: 0.7888322292174486
validation L:0.7796170569837516


 28%|██▊       | 55/200 [01:16<03:14,  1.34s/it]

mean: 0.6328735757656325, std: 0.03798885606805677
training L: 0.7888549892318737
validation L:0.7797437461866992


 28%|██▊       | 56/200 [01:17<03:12,  1.34s/it]

mean: 0.6336673558270467, std: 0.03800663855394427
training L: 0.7889274391338159
validation L:0.779769623922496


 28%|██▊       | 57/200 [01:19<03:10,  1.33s/it]

mean: 0.6345197747454334, std: 0.037944321430231465
training L: 0.7889455502549761
validation L:0.7800823547353973


 29%|██▉       | 58/200 [01:20<03:16,  1.38s/it]

mean: 0.6357832789447202, std: 0.03807899012996211
training L: 0.7889392375910389
validation L:0.779929845966143


 30%|██▉       | 59/200 [01:21<03:13,  1.37s/it]

mean: 0.6364019636121019, std: 0.03815005899083669
training L: 0.7889935696126813
validation L:0.7804133933338419


 30%|███       | 60/200 [01:23<03:17,  1.41s/it]

mean: 0.6374840290487794, std: 0.03800236525374677
training L: 0.7889219542103453
validation L:0.7803538743136058


 30%|███       | 61/200 [01:24<03:12,  1.38s/it]

mean: 0.637778766308804, std: 0.0378343596086585
training L: 0.7888865627897239
validation L:0.779929845966143


 31%|███       | 62/200 [01:26<03:08,  1.37s/it]

mean: 0.6388077489404568, std: 0.03789524239582371
training L: 0.7890487155716379
validation L:0.7797437461866992


 32%|███▏      | 63/200 [01:27<03:06,  1.36s/it]

mean: 0.6392297369299728, std: 0.037880109618109344
training L: 0.7889707971350389
validation L:0.7801753717117804


 32%|███▏      | 64/200 [01:28<03:04,  1.36s/it]

mean: 0.6401475170824421, std: 0.03775352116254244
training L: 0.7890558421170666
validation L:0.7798032186713447


 32%|███▎      | 65/200 [01:30<03:02,  1.35s/it]

mean: 0.6408517402368207, std: 0.03767392863007117
training L: 0.7889589999700948
validation L:0.7802943643712347


 33%|███▎      | 66/200 [01:31<03:00,  1.34s/it]

mean: 0.6415438765502811, std: 0.0378771443774228
training L: 0.7889526885579281
validation L:0.7802348635046515


 34%|███▎      | 67/200 [01:32<02:58,  1.34s/it]

mean: 0.6418545848710732, std: 0.03785251692680205
training L: 0.7888078004426632
validation L:0.7796506750057204


 34%|███▍      | 68/200 [01:34<02:56,  1.34s/it]

mean: 0.6430145007766597, std: 0.03793944414056601
training L: 0.7890007027827205
validation L:0.779955762336969


 34%|███▍      | 69/200 [01:35<02:54,  1.33s/it]

mean: 0.6431724290954164, std: 0.03794557136403669
training L: 0.7889707971350389
validation L:0.7800488102501525


 35%|███▌      | 70/200 [01:36<02:53,  1.33s/it]

mean: 0.6438113379585376, std: 0.037867618304883076
training L: 0.7890377369774535
validation L:0.7801158889905458


 36%|███▌      | 71/200 [01:38<02:52,  1.34s/it]

mean: 0.6443665282661462, std: 0.03755890204180918
training L: 0.7890558421170666
validation L:0.7800488102501525


 36%|███▌      | 72/200 [01:39<02:51,  1.34s/it]

mean: 0.6449561462094501, std: 0.037872248734854085
training L: 0.789103033746505
validation L:0.7800228745711018


 36%|███▋      | 73/200 [01:40<02:49,  1.33s/it]

mean: 0.6453280841816352, std: 0.03770579676780109
training L: 0.789079437226201
validation L:0.7802348635046515


 37%|███▋      | 74/200 [01:42<02:47,  1.33s/it]

mean: 0.6457358736092284, std: 0.03751679354920816
training L: 0.7890188103711235
validation L:0.7802088892277198


 38%|███▊      | 75/200 [01:43<02:46,  1.33s/it]

mean: 0.646295564651153, std: 0.037746101125036374
training L: 0.789050352828609
validation L:0.7800228745711018


 38%|███▊      | 76/200 [01:44<02:45,  1.33s/it]

mean: 0.6462463095208183, std: 0.03750683690394778
training L: 0.7890676394952455
validation L:0.7802088892277198


 38%|███▊      | 77/200 [01:46<02:43,  1.33s/it]

mean: 0.6468440748456522, std: 0.0377434471441708
training L: 0.789055023923445
validation L:0.7800823547353973


 39%|███▉      | 78/200 [01:47<02:40,  1.32s/it]

mean: 0.6470649900498601, std: 0.03791573178697559
training L: 0.7891101542877647
validation L:0.7802348635046515


 40%|███▉      | 79/200 [01:48<02:40,  1.32s/it]

mean: 0.647273252066047, std: 0.03790480812766759
training L: 0.7889290947695804
validation L:0.7801158889905458


 40%|████      | 80/200 [01:50<02:38,  1.32s/it]

mean: 0.6475985561213322, std: 0.03768843110730487
training L: 0.788983418310133
validation L:0.7802348635046515


 40%|████      | 81/200 [01:51<02:37,  1.32s/it]

mean: 0.6478119237514677, std: 0.03795452502134771
training L: 0.789133749962623
validation L:0.7802683745044221


 41%|████      | 82/200 [01:52<02:36,  1.33s/it]

mean: 0.6477440890117648, std: 0.03784521988385853
training L: 0.788947203157942
validation L:0.7801158889905458


 42%|████▏     | 83/200 [01:54<02:35,  1.33s/it]

mean: 0.6483128562010473, std: 0.037709612182815766
training L: 0.7890078344596615
validation L:0.7800228745711018


 42%|████▏     | 84/200 [01:55<02:35,  1.34s/it]

mean: 0.6486097237679442, std: 0.03754975698081344
training L: 0.7890920507721979
validation L:0.7801753717117804


 42%|████▎     | 85/200 [01:56<02:33,  1.34s/it]

mean: 0.6489740447107436, std: 0.03753976711385344
training L: 0.7890676394952455
validation L:0.7802088892277198


 43%|████▎     | 86/200 [01:58<02:32,  1.34s/it]

mean: 0.6488574469954199, std: 0.03766541645707203
training L: 0.7890613318979696
validation L:0.7800228745711018


 44%|████▎     | 87/200 [01:59<02:30,  1.34s/it]

mean: 0.6489096932871466, std: 0.037728392982625425
training L: 0.789103033746505
validation L:0.779929845966143


 44%|████▍     | 88/200 [02:00<02:29,  1.34s/it]

mean: 0.649378472614487, std: 0.037616527532277295
training L: 0.7890621495955866
validation L:0.7801158889905458


 44%|████▍     | 89/200 [02:02<02:28,  1.33s/it]

mean: 0.6495942640500949, std: 0.03741697010708341
training L: 0.7890983569794735
validation L:0.7802088892277198


 45%|████▌     | 90/200 [02:03<02:26,  1.34s/it]

mean: 0.6495154665553903, std: 0.03751850003918505
training L: 0.789085744187785
validation L:0.7801158889905458


 46%|████▌     | 91/200 [02:04<02:25,  1.33s/it]

mean: 0.6498210439550074, std: 0.03762015187425466
training L: 0.789049534261322
validation L:0.7800823547353973


 46%|████▌     | 92/200 [02:06<02:23,  1.33s/it]

mean: 0.6500737481281029, std: 0.03747188069033617
training L: 0.7890204520990313
validation L:0.7801158889905458


 46%|████▋     | 93/200 [02:07<02:22,  1.33s/it]

mean: 0.6497623566221192, std: 0.037333214636729145
training L: 0.7889960379756298
validation L:0.7802088892277198


 47%|████▋     | 94/200 [02:08<02:21,  1.33s/it]

mean: 0.6497072156138498, std: 0.03765919770957377
training L: 0.7891211387219281
validation L:0.7802088892277198


 48%|████▊     | 95/200 [02:10<02:20,  1.34s/it]

mean: 0.6498812003882101, std: 0.03746079489707398
training L: 0.7889661359049114
validation L:0.7799893235720278


 48%|████▊     | 96/200 [02:11<02:19,  1.34s/it]

mean: 0.6499510457127998, std: 0.03757139404685042
training L: 0.7890259400463482
validation L:0.7801418439716312


 48%|████▊     | 97/200 [02:12<02:17,  1.34s/it]

mean: 0.6504438735835141, std: 0.03743523975332393
training L: 0.789085744187785
validation L:0.7804468847708381


 49%|████▉     | 98/200 [02:14<02:16,  1.34s/it]

mean: 0.649769088231089, std: 0.037543300462062286
training L: 0.7890440450916485
validation L:0.7802088892277198


 50%|████▉     | 99/200 [02:15<02:14,  1.34s/it]

mean: 0.6500536357439018, std: 0.03735566876637967
training L: 0.789133749962623
validation L:0.7802088892277198


 50%|█████     | 100/200 [02:16<02:13,  1.34s/it]

mean: 0.6501871478072954, std: 0.03762670929697387
training L: 0.7890314284859903
validation L:0.780327868852459


 50%|█████     | 101/200 [02:18<02:12,  1.34s/it]

mean: 0.6501271829508782, std: 0.03714695929733477
training L: 0.7891046628096455
validation L:0.7801418439716312


 51%|█████     | 102/200 [02:19<02:10,  1.33s/it]

mean: 0.6500918652320735, std: 0.03742554390392043
training L: 0.7890322484189752
validation L:0.7802088892277198


 52%|█████▏    | 103/200 [02:20<02:09,  1.33s/it]

mean: 0.6501412497418192, std: 0.03739204724230821
training L: 0.789122765054117
validation L:0.7801753717117804


 52%|█████▏    | 104/200 [02:22<02:08,  1.34s/it]

mean: 0.6497717523605562, std: 0.03760921137923994
training L: 0.7890920507721979
validation L:0.7801753717117804


 52%|█████▎    | 105/200 [02:23<02:07,  1.34s/it]

mean: 0.6501580367727385, std: 0.037468959594168946
training L: 0.7891219519488092
validation L:0.7802088892277198


 53%|█████▎    | 106/200 [02:24<02:05,  1.34s/it]

mean: 0.6498953478366677, std: 0.037546487407533365
training L: 0.7890314284859903
validation L:0.7801158889905458


 54%|█████▎    | 107/200 [02:26<02:04,  1.34s/it]

mean: 0.6498872552868251, std: 0.037587378927069
training L: 0.7890314284859903
validation L:0.7801418439716312


 54%|█████▍    | 108/200 [02:27<02:03,  1.34s/it]

mean: 0.6500029294532585, std: 0.037289182435219455
training L: 0.7891400550173424
validation L:0.7801158889905458


 55%|█████▍    | 109/200 [02:28<02:02,  1.35s/it]

mean: 0.6502461237072283, std: 0.03755941936064975
training L: 0.7891282572620312
validation L:0.779929845966143


 55%|█████▌    | 110/200 [02:30<02:00,  1.34s/it]

mean: 0.6501953252623862, std: 0.037394467274640464
training L: 0.7890865600239199
validation L:0.7802683745044221


 56%|█████▌    | 111/200 [02:31<01:59,  1.34s/it]

mean: 0.6498664065940516, std: 0.03725298950261908
training L: 0.7891463596950217
validation L:0.779929845966143


 56%|█████▌    | 112/200 [02:32<01:57,  1.34s/it]

mean: 0.6499946135774257, std: 0.037446181779797444
training L: 0.7890660047253043
validation L:0.780327868852459


 56%|█████▋    | 113/200 [02:34<01:56,  1.34s/it]

mean: 0.6500599020048369, std: 0.037402034475766774
training L: 0.7890204520990313
validation L:0.7801158889905458


 57%|█████▋    | 114/200 [02:35<01:54,  1.34s/it]

mean: 0.6504418660385398, std: 0.03726827419105667
training L: 0.7890739467153066
validation L:0.7802088892277198


 57%|█████▊    | 115/200 [02:36<01:53,  1.34s/it]

mean: 0.6503700653919925, std: 0.03724308902750212
training L: 0.7891880578271465
validation L:0.7802088892277198


 58%|█████▊    | 116/200 [02:38<01:52,  1.34s/it]

mean: 0.6501551283243268, std: 0.03723439041383075
training L: 0.7890684566969158
validation L:0.7801753717117804


 58%|█████▊    | 117/200 [02:39<01:51,  1.34s/it]

mean: 0.6497153058598419, std: 0.037271555619692566
training L: 0.7889960379756298
validation L:0.7802348635046515


 59%|█████▉    | 118/200 [02:40<01:49,  1.34s/it]

mean: 0.6501052327897997, std: 0.03726821926702292
training L: 0.789134562198203
validation L:0.7802683745044221


 60%|█████▉    | 119/200 [02:42<01:48,  1.34s/it]

mean: 0.6503252694044814, std: 0.03731639904103009
training L: 0.7890613318979696
validation L:0.7801158889905458


 60%|██████    | 120/200 [02:43<01:46,  1.34s/it]

mean: 0.6500738723572284, std: 0.037279360225059475
training L: 0.7890621495955866
validation L:0.779929845966143


 60%|██████    | 121/200 [02:44<01:45,  1.33s/it]

mean: 0.6498417167752963, std: 0.037079801938013406
training L: 0.7891046628096455
validation L:0.7802088892277198


 61%|██████    | 122/200 [02:46<01:44,  1.33s/it]

mean: 0.650265070274585, std: 0.037034828333994324
training L: 0.7890802535581868
validation L:0.7802683745044221


 62%|██████▏   | 123/200 [02:47<01:42,  1.33s/it]

mean: 0.6502467146360357, std: 0.037436068434452505
training L: 0.7891164598594708
validation L:0.7802088892277198


 62%|██████▏   | 124/200 [02:48<01:41,  1.33s/it]

mean: 0.6496430102933335, std: 0.03721716628517319
training L: 0.7890676394952455
validation L:0.7802088892277198


 62%|██████▎   | 125/200 [02:50<01:40,  1.33s/it]

mean: 0.6499345323683191, std: 0.0374692464488565
training L: 0.789050352828609
validation L:0.7802088892277198


 63%|██████▎   | 126/200 [02:51<01:38,  1.33s/it]

mean: 0.6495119998868814, std: 0.037422044042390024
training L: 0.7891573470486215
validation L:0.7801158889905458


 64%|██████▎   | 127/200 [02:52<01:37,  1.33s/it]

mean: 0.6494766795521983, std: 0.037273700075502667
training L: 0.7890865600239199
validation L:0.7801158889905458


 64%|██████▍   | 128/200 [02:54<01:39,  1.39s/it]

mean: 0.6498283857833278, std: 0.037200995213504846
training L: 0.7890865600239199
validation L:0.7802683745044221


 64%|██████▍   | 129/200 [02:55<01:37,  1.37s/it]

mean: 0.6494224017071586, std: 0.037259201298920665
training L: 0.789139243155959
validation L:0.7802088892277198


 65%|██████▌   | 130/200 [02:57<01:35,  1.36s/it]

mean: 0.6491423779067859, std: 0.03721442709299281
training L: 0.7891219519488092
validation L:0.7800228745711018


 66%|██████▌   | 131/200 [02:58<01:33,  1.36s/it]

mean: 0.6499734358614128, std: 0.036943805382440485
training L: 0.7891526639956945
validation L:0.7800228745711018


 66%|██████▌   | 132/200 [02:59<01:35,  1.40s/it]

mean: 0.6500218696465765, std: 0.037401049570965614
training L: 0.789133749962623
validation L:0.7802348635046515


 66%|██████▋   | 133/200 [03:01<01:32,  1.38s/it]

mean: 0.6494513582998024, std: 0.037263276859249905
training L: 0.7890739467153066
validation L:0.780327868852459


 67%|██████▋   | 134/200 [03:02<01:30,  1.37s/it]

mean: 0.6494726131571438, std: 0.03723299795554167
training L: 0.789002347242364
validation L:0.779929845966143


 68%|██████▊   | 135/200 [03:03<01:28,  1.36s/it]

mean: 0.6495496750291943, std: 0.03708250920514626
training L: 0.7891046628096455
validation L:0.7802683745044221


 68%|██████▊   | 136/200 [03:05<01:26,  1.35s/it]

mean: 0.6494956205671558, std: 0.0371957410045398
training L: 0.789085744187785
validation L:0.7802683745044221


 68%|██████▊   | 137/200 [03:06<01:24,  1.33s/it]

mean: 0.650535865233708, std: 0.037154148136127794
training L: 0.7890802535581868
validation L:0.7799893235720278


 69%|██████▉   | 138/200 [03:07<01:22,  1.33s/it]

mean: 0.6495115698645602, std: 0.037151209592330675
training L: 0.7890322484189752
validation L:0.7802088892277198


 70%|██████▉   | 139/200 [03:09<01:21,  1.33s/it]

mean: 0.6493358151127034, std: 0.03723563590252677
training L: 0.7890377369774535
validation L:0.7802088892277198
mean: 0.6496482683585386, std: 0.03737806580611282
training L: 0.7889417182500523
validation L:0.7802683745044221


 70%|███████   | 141/200 [03:12<01:21,  1.38s/it]

mean: 0.6503415530453325, std: 0.03719735407667015
training L: 0.7891046628096455
validation L:0.7801158889905458


 71%|███████   | 142/200 [03:13<01:22,  1.42s/it]

mean: 0.6505178426903855, std: 0.037125989000501965
training L: 0.789133749962623
validation L:0.7802683745044221


 72%|███████▏  | 143/200 [03:14<01:19,  1.39s/it]

mean: 0.6491846292292962, std: 0.03756563606990346
training L: 0.7890377369774535
validation L:0.7800823547353973


 72%|███████▏  | 144/200 [03:16<01:17,  1.38s/it]

mean: 0.6489230551477677, std: 0.03713717678777013
training L: 0.7891463596950217
validation L:0.780327868852459


 72%|███████▎  | 145/200 [03:17<01:14,  1.36s/it]

mean: 0.6500091637951761, std: 0.03723965317236613
training L: 0.7890920507721979
validation L:0.780327868852459


 73%|███████▎  | 146/200 [03:18<01:13,  1.36s/it]

mean: 0.6505028659865627, std: 0.03726119110477237
training L: 0.7891282572620312
validation L:0.7801753717117804


 74%|███████▎  | 147/200 [03:20<01:11,  1.35s/it]

mean: 0.6508349413277549, std: 0.03710587107158111
training L: 0.7890322484189752
validation L:0.7802683745044221


 74%|███████▍  | 148/200 [03:21<01:09,  1.34s/it]

mean: 0.6497215027685552, std: 0.03727335028403669
training L: 0.789050352828609
validation L:0.7800228745711018


 74%|███████▍  | 149/200 [03:23<01:11,  1.40s/it]

mean: 0.6484142505097515, std: 0.03713637206976326
training L: 0.7890975420130375
validation L:0.7801158889905458


 75%|███████▌  | 150/200 [03:24<01:09,  1.38s/it]

mean: 0.6485851457429599, std: 0.03712819891392657
training L: 0.7891101542877647
validation L:0.7801158889905458


 76%|███████▌  | 151/200 [03:25<01:07,  1.37s/it]

mean: 0.6499649553057815, std: 0.037177752815117135
training L: 0.7891518531254205
validation L:0.780327868852459


 76%|███████▌  | 152/200 [03:27<01:05,  1.36s/it]

mean: 0.6509560817307467, std: 0.03726366366839285
training L: 0.7891164598594708
validation L:0.7801753717117804


 76%|███████▋  | 153/200 [03:28<01:03,  1.36s/it]

mean: 0.6511630880661888, std: 0.03718854248276555
training L: 0.7890739467153066
validation L:0.7801753717117804


 77%|███████▋  | 154/200 [03:29<01:02,  1.35s/it]

mean: 0.6499539006502881, std: 0.03738341670309168
training L: 0.7890865600239199
validation L:0.7802088892277198


 78%|███████▊  | 155/200 [03:31<01:00,  1.35s/it]

mean: 0.6503003855274889, std: 0.03720285269586973
training L: 0.789050352828609
validation L:0.7801418439716312


 78%|███████▊  | 156/200 [03:32<00:59,  1.35s/it]

mean: 0.6498567126833039, std: 0.037093566403896376
training L: 0.789079437226201
validation L:0.7800228745711018


 78%|███████▊  | 157/200 [03:34<01:00,  1.40s/it]

mean: 0.6503346970339198, std: 0.037376419242272814
training L: 0.7890802535581868
validation L:0.7802088892277198


 79%|███████▉  | 158/200 [03:35<00:57,  1.38s/it]

mean: 0.651107239786559, std: 0.037211889725529586
training L: 0.7890865600239199
validation L:0.7801158889905458


 80%|███████▉  | 159/200 [03:36<00:58,  1.42s/it]

mean: 0.650475750492941, std: 0.03715005283271281
training L: 0.7890865600239199
validation L:0.7801753717117804


 80%|████████  | 160/200 [03:38<00:55,  1.39s/it]

mean: 0.6498578179262586, std: 0.037078632658453556
training L: 0.7890865600239199
validation L:0.7801753717117804


 80%|████████  | 161/200 [03:39<00:53,  1.38s/it]

mean: 0.6516970899085567, std: 0.03730244981036209
training L: 0.7890259400463482
validation L:0.7802088892277198


 81%|████████  | 162/200 [03:40<00:51,  1.36s/it]

mean: 0.6520823474398048, std: 0.037284116516145244
training L: 0.7890747634210881
validation L:0.7802088892277198


 82%|████████▏ | 163/200 [03:42<00:50,  1.36s/it]

mean: 0.6507470263261494, std: 0.03741363351940397
training L: 0.789056660188369
validation L:0.7802088892277198


 82%|████████▏ | 164/200 [03:43<00:48,  1.35s/it]

mean: 0.6504946189314396, std: 0.03720300507104151
training L: 0.7890621495955866
validation L:0.7800228745711018


 82%|████████▎ | 165/200 [03:45<00:48,  1.39s/it]

mean: 0.6502906632818455, std: 0.03728390134381776
training L: 0.7891219519488092
validation L:0.7802088892277198


 83%|████████▎ | 166/200 [03:46<00:46,  1.38s/it]

mean: 0.6499858704657006, std: 0.03723080418488043
training L: 0.7890440450916485
validation L:0.7802683745044221


 84%|████████▎ | 167/200 [03:47<00:45,  1.36s/it]

mean: 0.6504895816566616, std: 0.03741216036682273
training L: 0.789133749962623
validation L:0.7800228745711018


 84%|████████▍ | 168/200 [03:49<00:43,  1.36s/it]

mean: 0.6512761960166494, std: 0.03729500754168066
training L: 0.789134562198203
validation L:0.7802088892277198


 84%|████████▍ | 169/200 [03:50<00:41,  1.35s/it]

mean: 0.6518689578150552, std: 0.037488372087739356
training L: 0.7890802535581868
validation L:0.7801158889905458


 85%|████████▌ | 170/200 [03:51<00:40,  1.35s/it]

mean: 0.6519314242298094, std: 0.037489673191777144
training L: 0.7891101542877647
validation L:0.7801753717117804


 86%|████████▌ | 171/200 [03:53<00:39,  1.35s/it]

mean: 0.650693017283545, std: 0.03731097606676058
training L: 0.7890920507721979
validation L:0.780327868852459


 86%|████████▌ | 172/200 [03:54<00:37,  1.35s/it]

mean: 0.6510142397210906, std: 0.03743140358968756
training L: 0.789050352828609
validation L:0.7802088892277198


 86%|████████▋ | 173/200 [03:55<00:36,  1.35s/it]

mean: 0.6514361465637869, std: 0.03725046086016204
training L: 0.7890684566969158
validation L:0.780327868852459


 87%|████████▋ | 174/200 [03:57<00:34,  1.35s/it]

mean: 0.6513559322656731, std: 0.03723085487433993
training L: 0.7891581575445888
validation L:0.7802088892277198


 88%|████████▊ | 175/200 [03:58<00:33,  1.34s/it]

mean: 0.6509152845170848, std: 0.03743742924000529
training L: 0.7891101542877647
validation L:0.7802683745044221


 88%|████████▊ | 176/200 [03:59<00:32,  1.34s/it]

mean: 0.651684699399912, std: 0.03757594819922277
training L: 0.7890802535581868
validation L:0.7801158889905458


 88%|████████▊ | 177/200 [04:01<00:30,  1.34s/it]

mean: 0.6507555980330351, std: 0.03752838806628672
training L: 0.7891400550173424
validation L:0.7802088892277198


 89%|████████▉ | 178/200 [04:02<00:29,  1.34s/it]

mean: 0.6511253328075242, std: 0.03743421966639613
training L: 0.7891164598594708
validation L:0.780327868852459


 90%|████████▉ | 179/200 [04:03<00:28,  1.34s/it]

mean: 0.6521646833636132, std: 0.037542516726807904
training L: 0.7891046628096455
validation L:0.7801158889905458


 90%|█████████ | 180/200 [04:05<00:27,  1.39s/it]

mean: 0.6517703879675681, std: 0.03756854569544114
training L: 0.7891699557469202
validation L:0.7802683745044221


 90%|█████████ | 181/200 [04:06<00:26,  1.38s/it]

mean: 0.6502751314481255, std: 0.037648384066966877
training L: 0.789110968262748
validation L:0.780327868852459


 91%|█████████ | 182/200 [04:08<00:24,  1.37s/it]

mean: 0.6514952344314373, std: 0.03747867401004681
training L: 0.7891164598594708
validation L:0.7802088892277198


 92%|█████████▏| 183/200 [04:09<00:23,  1.36s/it]

mean: 0.6526221874426001, std: 0.037801697534765845
training L: 0.7891164598594708
validation L:0.7802088892277198


 92%|█████████▏| 184/200 [04:10<00:21,  1.36s/it]

mean: 0.6525767401123828, std: 0.0376094175963354
training L: 0.7890920507721979
validation L:0.7802088892277198


 92%|█████████▎| 185/200 [04:12<00:21,  1.40s/it]

mean: 0.6525164012595259, std: 0.037638741217422444
training L: 0.789134562198203
validation L:0.7802088892277198


 93%|█████████▎| 186/200 [04:13<00:19,  1.38s/it]

mean: 0.6523958507033527, std: 0.03754099197440361
training L: 0.7890684566969158
validation L:0.7801158889905458


 94%|█████████▎| 187/200 [04:14<00:17,  1.37s/it]

mean: 0.6523987897332388, std: 0.037702384989835296
training L: 0.7891282572620312
validation L:0.7800228745711018


 94%|█████████▍| 188/200 [04:16<00:16,  1.36s/it]

mean: 0.6524044044291957, std: 0.03780756563922846
training L: 0.7891282572620312
validation L:0.7802088892277198


 94%|█████████▍| 189/200 [04:17<00:15,  1.36s/it]

mean: 0.651952320904239, std: 0.037778805295195735
training L: 0.7891219519488092
validation L:0.7802088892277198


 95%|█████████▌| 190/200 [04:18<00:13,  1.36s/it]

mean: 0.6520589789247596, std: 0.03766379883222292
training L: 0.7891046628096455
validation L:0.7801753717117804


 96%|█████████▌| 191/200 [04:20<00:12,  1.35s/it]

mean: 0.6527499982519376, std: 0.03785650292542167
training L: 0.7891219519488092
validation L:0.7801753717117804


 96%|█████████▌| 192/200 [04:21<00:10,  1.35s/it]

mean: 0.6517534958450468, std: 0.037894126614088454
training L: 0.7891518531254205
validation L:0.7802348635046515


 96%|█████████▋| 193/200 [04:23<00:09,  1.40s/it]

mean: 0.6502472349816975, std: 0.03800010367484725
training L: 0.7891518531254205
validation L:0.7802088892277198


 97%|█████████▋| 194/200 [04:24<00:08,  1.38s/it]

mean: 0.6513283861060108, std: 0.03792280645361822
training L: 0.7891400550173424
validation L:0.7802683745044221


 98%|█████████▊| 195/200 [04:25<00:06,  1.37s/it]

mean: 0.6521683720373683, std: 0.037860077419014454
training L: 0.789122765054117
validation L:0.7802088892277198


 98%|█████████▊| 196/200 [04:27<00:05,  1.41s/it]

mean: 0.652412709201666, std: 0.03828000420982498
training L: 0.7891219519488092
validation L:0.7802683745044221


 98%|█████████▊| 197/200 [04:28<00:04,  1.39s/it]

mean: 0.6528608022047981, std: 0.038130725836294684
training L: 0.7890975420130375
validation L:0.7799893235720278


 99%|█████████▉| 198/200 [04:30<00:02,  1.38s/it]

mean: 0.6532184313266275, std: 0.03817823601465069
training L: 0.7891400550173424
validation L:0.7801753717117804


100%|█████████▉| 199/200 [04:31<00:01,  1.37s/it]

mean: 0.652308709116172, std: 0.038421576275752184
training L: 0.7891518531254205
validation L:0.7801753717117804


100%|██████████| 200/200 [04:32<00:00,  1.36s/it]


mean: 0.6524606512561358, std: 0.038428353510256905
training L: 0.7890322484189752
validation L:0.7802348635046515


'time=2023-04-04 20:54:46.715070_epoch=196.pt'

# Training Loop - For sentence level feature Only model (This is actually for word level only)

In [14]:
config = json.load(open("/scratch/ondemand27/evanpan/Gaze_project/training/word_config.json", "r"))
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# obtain the dataset
torch.set_default_tensor_type(torch.DoubleTensor)
training_dataset = Aversion_SelfTap111(dataset_location, training_set, word_timing=True)
validation_dataset = Aversion_SelfTap111(dataset_location, testing_set, word_timing=True)
train_dataloader = torch.utils.data.DataLoader(training_dataset, config['batch_size'], True)
valid_dataloader = torch.utils.data.DataLoader(validation_dataset, config['batch_size'], True)
model = SentenceBaseline_GazePredictionModel(config)
model.to(device)
train_model(model, config, train_dataloader, valid_dataloader, run_obj, model_save_location)

NameError: name 'A' is not defined

## further train the model

In [1]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
config = json.load(open("/scratch/ondemand27/evanpan/Gaze_project/training/word_config.json", "r"))
model = SentenceBaseline_GazePredictionModel(config)
config["learning_rate"] = 0.00001
config["load_model"] = True
if config["wandb"]:
    wandb.login()
    if config["load_model"]:
        run_obj = wandb.init(project="gaze_prediction", config=config, save_code=True,
            resume='allow', id='8w9fyxan')
        # checkpoint_name = "gaze_prediction_team/gaze_prediction/8w9fyxan"
        checkpoint_path = "/scratch/ondemand27/evanpan/data/Gaze_aversion_models/time=2023-04-05 02:37:34.205141_epoch=200.pt"
        wandb.restore(checkpoint_path)
        pretrained_dict = torch.load(checkpoint_path, map_location=device)
        model.load_weights(pretrained_dict)
        model.to(config['device'])
    else:
        run_obj = wandb.init(project="gaze_prediction", config=config, settings=wandb.Settings(start_method="fork"))
else:
    run_obj = None

training_dataset = Aversion_SelfTap111(dataset_location, training_set, sentence_timing=True)
validation_dataset = Aversion_SelfTap111(dataset_location, testing_set, sentence_timing=True)
train_dataloader = torch.utils.data.DataLoader(training_dataset, config['batch_size'], True)
valid_dataloader = torch.utils.data.DataLoader(validation_dataset, config['batch_size'], True)
train_model(model, config, train_dataloader, valid_dataloader, run_obj, model_save_location)

NameError: name 'torch' is not defined

# Training Loop - For sentence level feature Only model (It's real this time)

In [11]:
run_obj = wandb.init(project="gaze_prediction", config=config, settings=wandb.Settings(start_method="fork"))
config = json.load(open("/scratch/ondemand27/evanpan/Gaze_project/training/sentence_config.json", "r"))
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# obtain the dataset
torch.set_default_tensor_type(torch.DoubleTensor)
training_dataset = Aversion_SelfTap111(dataset_location, training_set[0:10], sentence_and_word_timing=True)
validation_dataset = Aversion_SelfTap111(dataset_location, testing_set[0:2], sentence_and_word_timing=True)
train_dataloader = torch.utils.data.DataLoader(training_dataset, config['batch_size'], True)
valid_dataloader = torch.utils.data.DataLoader(validation_dataset, config['batch_size'], True)
model = SentenceBaseline_GazePredictionModel(config)
model.to(device)
train_model(model, config, train_dataloader, valid_dataloader, run_obj, "sentence_and_words")
run_obj.finish()


Epoch 1, mean: 0.17789314687252045, std: 0.3824285864830017
training L: 0.38361384790041525
validation L:0.5202852326801098


  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


Epoch 2, mean: 0.5589313507080078, std: 0.49652254581451416
training L: 0.5135592779060695
validation L:0.5422346910798534
Epoch 3, mean: 0.8715725541114807, std: 0.3345702886581421
training L: 0.5083247797239047
validation L:0.5124457632512047, model have not improved for 1 iterations
Epoch 4, mean: 0.9777710437774658, std: 0.14742977917194366
training L: 0.4620488325757937
validation L:0.5003760834120156, model have not improved for 2 iterations
Epoch 5, mean: 0.9966717958450317, std: 0.05759573355317116
training L: 0.45366407573439577
validation L:0.49920991627236866
Epoch 6, mean: 0.9993893504142761, std: 0.024704914540052414
training L: 0.45135263240753554
validation L:0.49904111077228025
Epoch 7, mean: 0.999908447265625, std: 0.009570656344294548
training L: 0.4510088849744354
validation L:0.4989816254955779, model have not improved for 1 iterations
Epoch 8, mean: 0.9999695420265198, std: 0.005525789689272642
training L: 0.4510375310310284
validation L:0.49904111077228025, model 





Stopping early due to decrease in performance on validation set




VBox(children=(Label(value='2.544 MB of 2.544 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
percentage_predicted_aversion,▁▄▇██████████████████████████████
training loss,█▅▃▂▂▂▂▂▃▃▃▂▂▂▂▂▁▁▁▂▂▂▂▂▂▁▁▁▁▁▁▁▁
training_f1,▁██▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅
validation_f1,▄█▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▂▂▁▁▁▁▁▁▁▁▁
validation_loss,█▅▃▂▂▁▂▂▂▁▂▁▁▂▂▂▂▃▃▃▃▃▃▂▂▂▂▂▁▁▁▁▁

0,1
percentage_predicted_aversion,0.99969
training loss,0.67176
training_f1,0.45141
validation_f1,0.49892
validation_loss,0.65346


In [18]:
training_dataset = Aversion_SelfTap111(dataset_location, training_set, sentence_and_word_timing=True)
for i in range(0, len(training_dataset)):
    print(training_dataset[i][0].shape)

(250, 6)
(240, 6)
Unexpected exception formatting exception. Falling back to standard exception


Traceback (most recent call last):
  File "/scratch/ondemand27/evanpan/conda_env/.conda/envs/jaligaze/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3460, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/tmp/ipykernel_17584/808632026.py", line 3, in <module>
    print(training_dataset[i][0].shape)
  File "/scratch/ondemand27/evanpan/Gaze_project/Dataset_Util/dataloader.py", line 208, in __getitem__
    input_text_on_screen = np.concatenate([input_text_on_screen1, input_text_on_screen2], axis = 1)
  File "<__array_function__ internals>", line 180, in concatenate
ValueError: all the input array dimensions for the concatenation axis must match exactly, but along dimension 0, the array at index 0 has size 250 and the array at index 1 has size 240

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/scratch/ondemand27/evanpan/conda_env/.conda/envs/jaligaze/lib/python3.10/site-package

# Look at the output of the model

In [11]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
config = json.load(open("/scratch/ondemand27/evanpan/Gaze_project/training/sentence_config.json", "r"))
# obtain the dataset
torch.set_default_tensor_type(torch.DoubleTensor)
validation_dataset = Aversion_SelfTap111(dataset_location, testing_set, word_timing=True)
valid_dataloader = torch.utils.data.DataLoader(validation_dataset, config['batch_size'], True)
model = SentenceBaseline_GazePredictionModel(config)
checkpoint_path = "/scratch/ondemand27/evanpan/data/Gaze_aversion_models/time=2023-04-05 02:37:34.205141_epoch=200.pt"
pretrained_dict = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(pretrained_dict)
# train_model(model, config, train_dataloader, valid_dataloader, run_obj, model_save_location)


/usr/bin/nvidia-modprobe: unrecognized option: "-s"

ERROR: Invalid commandline, please run `/usr/bin/nvidia-modprobe --help`
       for usage information.

/usr/bin/nvidia-modprobe: unrecognized option: "-s"

ERROR: Invalid commandline, please run `/usr/bin/nvidia-modprobe --help`
       for usage information.



<All keys matched successfully>

In [10]:
for _, (X, Y) in enumerate(valid_data):
    with torch.no_grad():
        valid_batch_counter += 1
        X, Y = X.to(device), Y.to(device)
        if "Transformer" in config["model_type"]:
            all_zero = torch.zeros(Y.shape).to(device)
            pred = model(X, all_zero)
        else:
            pred = model(X)
        # binary_pred = torch.round(pred)
        binary_pred = torch.argmax(pred, axis=2, keepdim=True)
        f1_valid = f1_score(binary_pred, torch.unsqueeze(Y, axis=2)).item()
        total_valid_f1 += f1_valid
        del X, Y, pred
        torch.cuda.empty_cache()

NameError: name 'valid_data' is not defined

In [32]:
model.to(device)
i = 11
x = torch.from_numpy(validation_dataset[i][0]).to(device)
y = validation_dataset[i][1]
pred = model(torch.unsqueeze(x, axis=0))
pred = torch.softmax(pred, dim=2)
pred = pred.cpu().detach().numpy()
plt.plot(pred[0, :, 0], label="prediction")
plt.plot(y, label="label")
plt.legend()


IndexError: index 2 is out of bounds for dimension 2 with size 2