In [2]:
import math
import os
from methods import Segmenter, TransformerModel, LSTM, CRNN, UNet,CCRNN, UNet2, PatchTST
from matplotlib import pyplot as plt
import torch
from utilities import printc, seed
from utils_loader import get_dataloaders, test_idea_dataloader_er_ir, test_idea_dataloader_burpee_pushup
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm
from utils_metrics import mean_iou
from torch.nn import functional as F
from train_dense_labeling import get_model
from methods.unet import UNet_encoder
from sklearn.metrics import f1_score 
from transformers import PatchTSTConfig, PatchTSTForClassification
config = {
        'batch_size': 128,
        'epochs':200,
        'fsl': False,
        'model': 'unet',
        'seed': 73054772,
        'dataset': 'physiq'
    }
    
seed(config['seed'])
train_loader, val_loader, test_loader = test_idea_dataloader_er_ir(config)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
patch_config = PatchTSTConfig(
    num_attention_heads=1,
    num_input_channels=6,
    num_targets=2,
    context_length=50,
    patch_length=10,
    stride=10,
    use_cls_token=True,
)
model = PatchTSTForClassification(config=patch_config).float().to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# scheduler = ReduceLROnPlateau(optimizer, 'min', patience=10, verbose=True)

for epoch in range(50):
    model.train()
    for batch in tqdm(train_loader):
        x, y = batch
        x, y = x.to(device).float(), y.to(device).long()
        optimizer.zero_grad()
        y_hat = model(past_values=x)
        y_hat = y_hat.prediction_logits
        loss = F.cross_entropy(y_hat, y)
        loss.backward()
        optimizer.step()
    model.eval()
    with torch.no_grad():
        val_loss = 0
        accs = 0

        for batch in val_loader:
            x, y = batch
            x, y = x.to(device).float(), y.to(device).long()
            y_hat = model(past_values=x)
            y_hat = y_hat.prediction_logits
            accs += (y_hat.argmax(1) == y).float().mean()
            val_loss += F.cross_entropy(y_hat, y)
        val_loss /= len(val_loader)
    print(f'Epoch {epoch}, val_loss: {val_loss}')
    print(f'Epoch {epoch}, val_acc: {accs/len(val_loader)}')
    
    # scheduler.step(val_loss)


100%|██████████| 28/28 [00:00<00:00, 52.09it/s]


Epoch 0, val_loss: 0.6717113852500916
Epoch 0, val_acc: 0.6029942631721497


100%|██████████| 28/28 [00:00<00:00, 74.12it/s]


Epoch 1, val_loss: 0.6676727533340454
Epoch 1, val_acc: 0.5817888975143433


100%|██████████| 28/28 [00:00<00:00, 76.32it/s]


Epoch 2, val_loss: 0.7362098097801208
Epoch 2, val_acc: 0.5311034917831421


100%|██████████| 28/28 [00:00<00:00, 80.26it/s]


Epoch 3, val_loss: 0.7079066634178162
Epoch 3, val_acc: 0.5383974313735962


100%|██████████| 28/28 [00:00<00:00, 81.99it/s]


Epoch 4, val_loss: 0.6385185122489929
Epoch 4, val_acc: 0.5970869064331055


100%|██████████| 28/28 [00:00<00:00, 77.44it/s]


Epoch 5, val_loss: 0.6420630812644958
Epoch 5, val_acc: 0.6262964606285095


100%|██████████| 28/28 [00:00<00:00, 75.75it/s]


Epoch 6, val_loss: 0.6404599547386169
Epoch 6, val_acc: 0.6259695887565613


100%|██████████| 28/28 [00:00<00:00, 75.55it/s]


Epoch 7, val_loss: 0.7183954119682312
Epoch 7, val_acc: 0.5695910453796387


100%|██████████| 28/28 [00:00<00:00, 74.74it/s]


Epoch 8, val_loss: 0.6282199025154114
Epoch 8, val_acc: 0.6556412577629089


100%|██████████| 28/28 [00:00<00:00, 64.80it/s]


Epoch 9, val_loss: 0.6175586581230164
Epoch 9, val_acc: 0.6467126607894897


100%|██████████| 28/28 [00:00<00:00, 78.72it/s]


Epoch 10, val_loss: 0.9043963551521301
Epoch 10, val_acc: 0.6077854633331299


100%|██████████| 28/28 [00:00<00:00, 75.93it/s]


Epoch 11, val_loss: 0.6186016798019409
Epoch 11, val_acc: 0.6685718894004822


100%|██████████| 28/28 [00:00<00:00, 78.92it/s]


Epoch 12, val_loss: 0.6566580533981323
Epoch 12, val_acc: 0.5972222089767456


100%|██████████| 28/28 [00:00<00:00, 80.66it/s]


Epoch 13, val_loss: 0.6340305805206299
Epoch 13, val_acc: 0.6139633059501648


100%|██████████| 28/28 [00:00<00:00, 78.60it/s]


Epoch 14, val_loss: 0.6315634846687317
Epoch 14, val_acc: 0.6191378831863403


100%|██████████| 28/28 [00:00<00:00, 77.09it/s]


Epoch 15, val_loss: 0.650328516960144
Epoch 15, val_acc: 0.5976845026016235


100%|██████████| 28/28 [00:00<00:00, 80.43it/s]


Epoch 16, val_loss: 0.6174963712692261
Epoch 16, val_acc: 0.661413311958313


100%|██████████| 28/28 [00:00<00:00, 79.35it/s]


Epoch 17, val_loss: 0.6208751201629639
Epoch 17, val_acc: 0.6378405094146729


100%|██████████| 28/28 [00:00<00:00, 82.08it/s]


Epoch 18, val_loss: 0.7120439410209656
Epoch 18, val_acc: 0.6200059056282043


100%|██████████| 28/28 [00:00<00:00, 78.33it/s]


Epoch 19, val_loss: 0.6144059896469116
Epoch 19, val_acc: 0.6601055264472961


100%|██████████| 28/28 [00:00<00:00, 81.51it/s]


Epoch 20, val_loss: 0.6353275775909424
Epoch 20, val_acc: 0.6320684552192688


100%|██████████| 28/28 [00:00<00:00, 80.44it/s]


Epoch 21, val_loss: 0.6778340935707092
Epoch 21, val_acc: 0.6016864776611328


100%|██████████| 28/28 [00:00<00:00, 78.71it/s]


Epoch 22, val_loss: 0.6181671023368835
Epoch 22, val_acc: 0.6500608921051025


100%|██████████| 28/28 [00:00<00:00, 82.87it/s]


Epoch 23, val_loss: 0.6270979642868042
Epoch 23, val_acc: 0.6627209782600403


100%|██████████| 28/28 [00:00<00:00, 79.76it/s]


Epoch 24, val_loss: 0.6257294416427612
Epoch 24, val_acc: 0.6664750576019287


100%|██████████| 28/28 [00:00<00:00, 82.95it/s]


Epoch 25, val_loss: 0.7262900471687317
Epoch 25, val_acc: 0.6342442631721497


100%|██████████| 28/28 [00:00<00:00, 73.57it/s]


Epoch 26, val_loss: 0.657516360282898
Epoch 26, val_acc: 0.6022050976753235


100%|██████████| 28/28 [00:00<00:00, 78.28it/s]


Epoch 27, val_loss: 0.6232582926750183
Epoch 27, val_acc: 0.6607593894004822


100%|██████████| 28/28 [00:00<00:00, 79.05it/s]


Epoch 28, val_loss: 0.6383379101753235
Epoch 28, val_acc: 0.6635664701461792


100%|██████████| 28/28 [00:00<00:00, 82.71it/s]


Epoch 29, val_loss: 0.6410701870918274
Epoch 29, val_acc: 0.6393624544143677


100%|██████████| 28/28 [00:00<00:00, 79.88it/s]


Epoch 30, val_loss: 0.6186174154281616
Epoch 30, val_acc: 0.653600811958313


100%|██████████| 28/28 [00:00<00:00, 73.78it/s]


Epoch 31, val_loss: 0.6220186352729797
Epoch 31, val_acc: 0.6544463038444519


100%|██████████| 28/28 [00:00<00:00, 81.03it/s]


Epoch 32, val_loss: 0.6480771899223328
Epoch 32, val_acc: 0.6578733921051025


100%|██████████| 28/28 [00:00<00:00, 75.87it/s]


Epoch 33, val_loss: 0.6554809808731079
Epoch 33, val_acc: 0.6283933520317078


100%|██████████| 28/28 [00:00<00:00, 77.37it/s]


Epoch 34, val_loss: 0.7462990283966064
Epoch 34, val_acc: 0.6317415237426758


100%|██████████| 28/28 [00:00<00:00, 78.64it/s]


Epoch 35, val_loss: 0.6867462396621704
Epoch 35, val_acc: 0.6490012407302856


100%|██████████| 28/28 [00:00<00:00, 81.49it/s]


Epoch 36, val_loss: 0.6714732646942139
Epoch 36, val_acc: 0.5992627143859863


100%|██████████| 28/28 [00:00<00:00, 81.37it/s]


Epoch 37, val_loss: 0.6369431614875793
Epoch 37, val_acc: 0.6508500576019287


100%|██████████| 28/28 [00:00<00:00, 76.17it/s]


Epoch 38, val_loss: 0.6308726668357849
Epoch 38, val_acc: 0.6658775806427002


100%|██████████| 28/28 [00:00<00:00, 73.87it/s]


Epoch 39, val_loss: 0.661206066608429
Epoch 39, val_acc: 0.6547168493270874


100%|██████████| 28/28 [00:00<00:00, 79.53it/s]


Epoch 40, val_loss: 0.613369882106781
Epoch 40, val_acc: 0.6671852469444275


100%|██████████| 28/28 [00:00<00:00, 77.34it/s]


Epoch 41, val_loss: 0.6109002828598022
Epoch 41, val_acc: 0.6769030094146729


100%|██████████| 28/28 [00:00<00:00, 76.49it/s]


Epoch 42, val_loss: 0.6912501454353333
Epoch 42, val_acc: 0.6352250576019287


100%|██████████| 28/28 [00:00<00:00, 77.98it/s]


Epoch 43, val_loss: 0.6346443891525269
Epoch 43, val_acc: 0.6587188839912415


100%|██████████| 28/28 [00:00<00:00, 75.06it/s]


Epoch 44, val_loss: 0.6187400817871094
Epoch 44, val_acc: 0.6815589070320129


100%|██████████| 28/28 [00:00<00:00, 81.63it/s]


Epoch 45, val_loss: 0.6242417097091675
Epoch 45, val_acc: 0.6866206526756287


100%|██████████| 28/28 [00:00<00:00, 74.06it/s]


Epoch 46, val_loss: 0.6056951880455017
Epoch 46, val_acc: 0.6706687808036804


100%|██████████| 28/28 [00:00<00:00, 78.78it/s]


Epoch 47, val_loss: 0.6135711669921875
Epoch 47, val_acc: 0.6706124544143677


100%|██████████| 28/28 [00:00<00:00, 75.23it/s]


Epoch 48, val_loss: 0.6379120945930481
Epoch 48, val_acc: 0.6601619124412537


100%|██████████| 28/28 [00:00<00:00, 74.26it/s]


Epoch 49, val_loss: 0.6161254644393921
Epoch 49, val_acc: 0.673768937587738
