In [1]:
import os
import glob
import time
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from datetime import datetime

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import random_split, DataLoader

from eye_tracker_dataset import EyeTrackerDataset

In [2]:
class EyeTrackerNet(nn.Module):
    """Six-layer fully connected network mapping 243 input features to 2 outputs.

    The layers are exposed as ``fc1`` ... ``fc6`` (widths
    243 -> 128 -> 64 -> 32 -> 16 -> 8 -> 2) and every layer, including the
    last one, is followed by a ReLU.

    NOTE(review): because of the trailing ReLU both outputs are always >= 0;
    confirm the regression targets are non-negative.
    """

    # Feature width at every stage, from the input (81 * 3) to the output (2).
    _LAYER_WIDTHS = (81 * 3, 128, 64, 32, 16, 8, 2)

    def __init__(self):
        super().__init__()

        # Register fc1..fc6 in order; the attribute names double as the
        # state-dict key prefixes.
        stage_pairs = zip(self._LAYER_WIDTHS[:-1], self._LAYER_WIDTHS[1:])
        for position, (n_in, n_out) in enumerate(stage_pairs, start=1):
            setattr(self, "fc{}".format(position), nn.Linear(n_in, n_out))

    def forward(self, x):
        """Apply fc1..fc6, each followed by ReLU, to ``x``.

        ``x`` must have 243 (= 81 * 3) features in its last dimension; the
        result has 2 features in its last dimension.
        """
        n_layers = len(self._LAYER_WIDTHS) - 1
        for position in range(1, n_layers + 1):
            layer = getattr(self, "fc{}".format(position))
            x = F.relu(layer(x))
        return x

# Single notebook-level model instance; the optimizer setup and the
# train/valid helpers in later cells reference this global directly.
model = EyeTrackerNet()


In [3]:
# Load the full dataset and report how many samples it holds.
eye_tracker_dataset = EyeTrackerDataset()
print(len(eye_tracker_dataset))

# Seeded generator so the 70/30 train/test split is the same on every run.
torch_generator = torch.Generator().manual_seed(42)
train_dataset, test_dataset = random_split(eye_tracker_dataset, [0.7, 0.3], generator=torch_generator)
print(len(train_dataset), len(test_dataset))

train_loader = DataLoader(train_dataset, batch_size = 64, shuffle = True)
# The evaluation loss is a batch-size-weighted mean, so sample order does not
# matter; leaving the loader unshuffled also keeps validation from drawing
# random numbers from the global RNG on every pass.
test_loader  = DataLoader(test_dataset,  batch_size = 64, shuffle = False)

loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.Adam(
    params = model.parameters(),
    lr = 0.0001
)

# Smoke test: push one flattened sample through the model so a mismatch
# between the sample size and the first layer (81 * 3 inputs) fails here
# rather than in the middle of training.
y = model(eye_tracker_dataset[0][0].flatten())

3047
2133 914


In [4]:
def train_one_epoch(train_loader):
    """Run one optimisation pass over ``train_loader`` and return the mean loss.

    Uses the notebook-level ``model``, ``loss_fn`` and ``optimizer`` globals;
    the caller is responsible for putting ``model`` into training mode.

    Args:
        train_loader: sized iterable yielding ``(inputs, labels)`` batches.
            ``inputs.shape[0]`` is taken as the batch size.

    Returns:
        float: mean of the per-batch losses weighted by batch size, or NaN
        when the loader yields no samples.
    """
    loss_accumulated = 0.0
    n_data_accumulated = 0

    with tqdm(
        total = len(train_loader),
        desc = "train"
    ) as inner_pbar :
        for inputs, labels in train_loader:
            n_data = inputs.shape[0]

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_fn(outputs, labels)
            loss.backward()
            optimizer.step()

            # Read the scalar once and reuse it for both the running total
            # and the progress bar.
            curr_loss = loss.item()
            loss_accumulated += curr_loss * n_data
            n_data_accumulated += n_data

            inner_pbar.update(1)
            inner_pbar.set_postfix_str(curr_loss)

    # An empty loader has no mean; report NaN instead of failing on an
    # unbound local.
    if n_data_accumulated == 0:
        return float("nan")
    return loss_accumulated / n_data_accumulated

def valid_one_epoch(valid_loader) :
    """Evaluate the model on ``valid_loader`` and return the mean loss.

    Uses the notebook-level ``model`` and ``loss_fn`` globals; the caller is
    responsible for putting ``model`` into evaluation mode. No parameters are
    updated.

    Args:
        valid_loader: sized iterable yielding ``(inputs, labels)`` batches.
            ``inputs.shape[0]`` is taken as the batch size.

    Returns:
        float: mean of the per-batch losses weighted by batch size, or NaN
        when the loader yields no samples.
    """
    loss_accumulated = 0.0
    n_data_accumulated = 0

    with tqdm(
        total = len(valid_loader),
        desc = "valid"
    ) as inner_pbar :
        # Nothing is back-propagated here, so skip building autograd graphs.
        with torch.no_grad():
            for inputs, labels in valid_loader:
                n_data = inputs.shape[0]

                outputs = model(inputs)
                loss = loss_fn(outputs, labels)

                curr_loss = loss.item()
                loss_accumulated += curr_loss * n_data
                n_data_accumulated += n_data

                inner_pbar.update(1)
                inner_pbar.set_postfix_str(curr_loss)

    # An empty loader has no mean; report NaN instead of failing on an
    # unbound local.
    if n_data_accumulated == 0:
        return float("nan")
    return loss_accumulated / n_data_accumulated

In [5]:
# Timestamp of this run; it is embedded in the checkpoint file names.
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
# Maximum number of entries kept in the best-checkpoint cache.
N_MODEL_TO_SAVE = 3

# NOTE(review): model_dict_list is not referenced by the other cells visible
# in this notebook; kept in case something else relies on it.
model_dict_list = []
train_loss_list = []
valid_loss_list = []
# Seed the cache with a placeholder whose loss is worse than any real value,
# so the first finite validation loss always qualifies (a fixed number such as
# 100 would silently reject every epoch whose loss is above it).
train_valid_loop_cache = [{"valid_loss": float("inf")}]

# First epoch index the training cell will run next.
curr_epoch_idx = 0

In [6]:
# Number of epochs to run each time this cell is executed. Re-running the cell
# continues from curr_epoch_idx and keeps appending to the loss histories.
N_EPOCHS = 200

with tqdm(total=N_EPOCHS) as outer_pbar :
    for epoch in range(curr_epoch_idx, curr_epoch_idx + N_EPOCHS) :
        model.train()
        whole_mean_loss_train = train_one_epoch(train_loader)
        train_loss_list.append(whole_mean_loss_train)

        model.eval()
        whole_mean_loss_valid = valid_one_epoch(test_loader)
        valid_loss_list.append(whole_mean_loss_valid)

        # The cache is kept sorted by validation loss, so its last entry is
        # the worst one currently retained.
        if whole_mean_loss_valid < train_valid_loop_cache[-1]["valid_loss"] :
            train_valid_loop_cache.append({
                # Clone every tensor: state_dict() hands back references to
                # the live parameters, and a shallow dict copy would leave all
                # cached checkpoints pointing at the latest weights.
                "state_dict": {
                    name: tensor.detach().clone()
                    for name, tensor in model.state_dict().items()
                },
                "train_loss": whole_mean_loss_train,
                "valid_loss": whole_mean_loss_valid,
                "epoch"     : epoch
            })
            train_valid_loop_cache.sort(key = lambda data: data["valid_loss"])
            if len(train_valid_loop_cache) > N_MODEL_TO_SAVE :
                train_valid_loop_cache.pop(-1)

        outer_pbar.update(1)
        outer_pbar.set_postfix_str(str(whole_mean_loss_valid))

        print("tloss : {:.5f} vloss : {:.5f}".format(whole_mean_loss_train, whole_mean_loss_valid))

        # Advance only after the epoch has fully completed: an interrupted
        # epoch is redone on the next run, and a finished one is not repeated.
        curr_epoch_idx = epoch + 1

# torch.save does not create missing parent directories.
os.makedirs("checkpoints", exist_ok=True)

for model_data in train_valid_loop_cache :
    # The placeholder entry that seeds the cache carries only "valid_loss";
    # it is still present when fewer than N_MODEL_TO_SAVE epochs improved.
    if "state_dict" not in model_data :
        continue
    model_path = os.path.join(
        "checkpoints",
        'model_{}_{}'.format(timestamp, model_data["epoch"])
    )
    torch.save(model_data["state_dict"], model_path)

# Loss history per epoch for every epoch run since the lists were created.
fig, ax = plt.subplots()
ax.plot(train_loss_list, label="train_loss")
ax.plot(valid_loss_list, label="valid_loss")
ax.set_xlabel("epoch")
ax.set_ylabel("loss")
ax.set_title("Training vs. validation loss")
ax.legend(loc="upper right")
plt.show()

  0%|          | 0/200 [00:00<?, ?it/s]

train:   0%|          | 0/34 [00:00<?, ?it/s]

valid:   0%|          | 0/15 [00:00<?, ?it/s]

tloss : 0.00398 vloss : 0.00432


train:   0%|          | 0/34 [00:00<?, ?it/s]

valid:   0%|          | 0/15 [00:00<?, ?it/s]

tloss : 0.00394 vloss : 0.00423


train:   0%|          | 0/34 [00:00<?, ?it/s]

valid:   0%|          | 0/15 [00:00<?, ?it/s]

tloss : 0.00390 vloss : 0.00419


train:   0%|          | 0/34 [00:00<?, ?it/s]

valid:   0%|          | 0/15 [00:00<?, ?it/s]

tloss : 0.00334 vloss : 0.00283


train:   0%|          | 0/34 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# Show the value left in `epoch` by the training cell's `for` loop, i.e. the
# index of the last epoch that was started.
epoch