In [1]:
# Re-import edited modules (e.g. lib.modules, lib.models, lib.utils) automatically
# before each cell runs, so changes to the project code are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [4]:
import os
from pathlib import Path

import matplotlib.pyplot as plt
import plotly.express as px
import torch
from tabulate import tabulate
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

from lib.models import MLP, MLP2hl
from lib.modules import (
    evaluate_loop,
    optimization_loop,
    predict_and_plot_pretty_session,
    read_and_window_session,
    read_session,
    train_loop,
)
from lib.utils import (
    plot_and_save_cm,
    summary,
)

In [10]:
# Dataset root: override with the NURSING_DATASETS_DIR environment variable on
# other machines instead of editing this absolute path.
DATASETS_DIR = Path(os.environ.get("NURSING_DATASETS_DIR", "/home/musa/datasets"))
raw_dir = DATASETS_DIR / "nursingv1"
label_dir = DATASETS_DIR / "eating_labels"

# Window length passed to read_and_window_session and used as the MLP input size.
WINSIZE = 101

# Seed torch's default RNG (all devices) so weight initialisation and the
# DataLoader's shuffle order are reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Prefer the second GPU (the original choice), but fall back so the notebook
# still runs on machines with a single GPU or none at all.
if torch.cuda.device_count() > 1:
    DEVICE = 'cuda:1'
elif torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

In [11]:
train_sessions = [25, 67, 42]
test_sessions = [58, 62]

# A session in both splits would leak training data into the test metrics.
assert not set(train_sessions) & set(test_sessions), "train/test sessions overlap"


def load_windowed_sessions(session_ids):
    """Window every session in ``session_ids`` and stack them into one dataset.

    Args:
        session_ids: iterable of session indices accepted by
            ``read_and_window_session``.

    Returns:
        ``(X, y)``: the per-session tensors returned by
        ``read_and_window_session``, concatenated along dim 0.

    Raises:
        ValueError: if ``session_ids`` is empty (``torch.cat`` would otherwise
            fail with a less helpful message).
    """
    windows, labels = [], []
    for session_idx in session_ids:
        X, y = read_and_window_session(session_idx, WINSIZE, raw_dir, label_dir)
        windows.append(X)
        labels.append(y)
    if not windows:
        raise ValueError("session_ids must contain at least one session")
    return torch.cat(windows), torch.cat(labels)


Xtr, ytr = load_windowed_sessions(train_sessions)
Xte, yte = load_windowed_sessions(test_sessions)

# Every window must have exactly one label entry.
assert len(Xtr) == len(ytr), "train windows/labels length mismatch"
assert len(Xte) == len(yte), "test windows/labels length mismatch"

In [18]:
# Two-hidden-layer MLP (layer widths below; WINSIZE presumably sets the input
# size — see lib.models.MLP2hl). The loss takes raw logits, and Adam runs with
# its default hyper-parameters.
HIDDEN_SIZES = [20, 20]
model = MLP2hl(HIDDEN_SIZES, WINSIZE).to(DEVICE)
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters())

In [19]:
BATCH_SIZE = 64
N_EPOCHS = 50
# Output directory handed to optimization_loop (presumably where it writes best_model.pt).
CKPT_DIR = Path('dev/mlp2hl')

# Only the training windows are shuffled; evaluation order stays fixed.
train_ds = TensorDataset(Xtr, ytr)
test_ds = TensorDataset(Xte, yte)
trainloader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
testloader = DataLoader(test_ds, batch_size=BATCH_SIZE)

optimization_loop(model, trainloader, testloader, criterion, optimizer, N_EPOCHS, DEVICE, CKPT_DIR)

Epoch 49: Train Loss: 0.22026: Dev Loss: 0.85166: 100%|██████████| 50/50 [08:51<00:00, 10.63s/it]


In [15]:
# Restore the best weights, presumably written by optimization_loop(..., Path('dev/mlp2hl'))
# in the training cell, so run that cell first. The file must come from an MLP2hl
# with the same layer sizes. A checkpoint left in this directory by a different
# architecture fails with "Missing/Unexpected key(s) in state_dict".
# map_location lets a checkpoint saved on another device load onto DEVICE.
# NOTE: torch.load unpickles the file — only load checkpoints you produced yourself.
model.load_state_dict(torch.load(Path('dev/mlp2hl/best_model.pt'), map_location=DEVICE))

RuntimeError: Error(s) in loading state_dict for MLP2hl:
	Missing key(s) in state_dict: "h1.weight", "h1.bias", "h2.weight", "h2.bias", "out.weight", "out.bias". 
	Unexpected key(s) in state_dict: "resnet_conv.shortcut1.0.weight", "resnet_conv.shortcut1.0.bias", "resnet_conv.shortcut1.1.weight", "resnet_conv.shortcut1.1.bias", "resnet_conv.shortcut1.1.running_mean", "resnet_conv.shortcut1.1.running_var", "resnet_conv.shortcut1.1.num_batches_tracked", "resnet_conv.res1.0.0.weight", "resnet_conv.res1.0.0.bias", "resnet_conv.res1.0.1.weight", "resnet_conv.res1.0.1.bias", "resnet_conv.res1.0.1.running_mean", "resnet_conv.res1.0.1.running_var", "resnet_conv.res1.0.1.num_batches_tracked", "resnet_conv.res1.1.0.weight", "resnet_conv.res1.1.0.bias", "resnet_conv.res1.1.1.weight", "resnet_conv.res1.1.1.bias", "resnet_conv.res1.1.1.running_mean", "resnet_conv.res1.1.1.running_var", "resnet_conv.res1.1.1.num_batches_tracked", "resnet_conv.res1.2.0.weight", "resnet_conv.res1.2.0.bias", "resnet_conv.res1.2.1.weight", "resnet_conv.res1.2.1.bias", "resnet_conv.res1.2.1.running_mean", "resnet_conv.res1.2.1.running_var", "resnet_conv.res1.2.1.num_batches_tracked", "resnet_conv.shortcut3.0.weight", "resnet_conv.shortcut3.0.bias", "resnet_conv.shortcut3.1.weight", "resnet_conv.shortcut3.1.bias", "resnet_conv.shortcut3.1.running_mean", "resnet_conv.shortcut3.1.running_var", "resnet_conv.shortcut3.1.num_batches_tracked", "resnet_conv.res3.0.0.weight", "resnet_conv.res3.0.0.bias", "resnet_conv.res3.0.1.weight", "resnet_conv.res3.0.1.bias", "resnet_conv.res3.0.1.running_mean", "resnet_conv.res3.0.1.running_var", "resnet_conv.res3.0.1.num_batches_tracked", "resnet_conv.res3.1.0.weight", "resnet_conv.res3.1.0.bias", "resnet_conv.res3.1.1.weight", "resnet_conv.res3.1.1.bias", "resnet_conv.res3.1.1.running_mean", "resnet_conv.res3.1.1.running_var", "resnet_conv.res3.1.1.num_batches_tracked", "resnet_conv.res3.2.0.weight", "resnet_conv.res3.2.0.bias", "resnet_conv.res3.2.1.weight", "resnet_conv.res3.2.1.bias", "resnet_conv.res3.2.1.running_mean", "resnet_conv.res3.2.1.running_var", "resnet_conv.res3.2.1.num_batches_tracked", "output.weight", 
"output.bias". 

In [None]:
# Evaluate on the training split. Stage-prefixed names keep these results
# available after the test-split evaluation runs, instead of both cells
# rebinding the generic `ys` / `metrics`.
train_ys, train_metrics = evaluate_loop(model, criterion, trainloader, DEVICE)
plot_and_save_cm(train_ys['true'], train_ys['pred'])
summary(train_metrics)

In [None]:
# Evaluate on the held-out sessions. Stage-prefixed names so these results do
# not clobber (or get clobbered by) the training-split evaluation.
test_ys, test_metrics = evaluate_loop(model, criterion, testloader, DEVICE)
plot_and_save_cm(test_ys['true'], test_ys['pred'])
summary(test_metrics)

In [None]:
# Plot the model's predictions for the first held-out session.
test_session = test_sessions[0]

predict_and_plot_pretty_session(
    # which session to plot, and where its raw data and labels live
    session_idx=test_session,
    datapath=raw_dir,
    labelpath=label_dir,
    # model and inference settings (same window size / device as training)
    model=model,
    criterion=criterion,
    winsize=WINSIZE,
    batch_size=64,
    device=DEVICE,
    # presumably a display parameter — see lib.modules.predict_and_plot_pretty_session
    dim_factor=5,
)